未验证 提交 11997c8e 编写于 作者: Y Yibing Liu 提交者: GitHub

Cherry-pick #16206 to release 1.3 (#16218)

* Impl fp16 compute kernel for slice_op

* Use data() to replace mutable_data()

test=release/1.3
上级 38604b7c
......@@ -12,18 +12,138 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_vector.h>
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/slice_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
template <size_t D>
__global__ void Padding(const paddle::platform::float16* d_out,
const int* out_dims, const int* in_dims,
const int* offsets, int64_t n,
paddle::platform::float16* d_in) {
int64_t out_idx = threadIdx.x + blockDim.x * blockIdx.x;
if (out_idx < n) {
int coords[D] = {0};
for (int i = D - 1; i >= 0; --i) {
coords[i] = out_idx % out_dims[i];
out_idx /= out_dims[i];
coords[i] += offsets[i];
}
int64_t in_idx = 0;
for (int i = 0; i < D - 1; ++i) {
in_idx += coords[i] * in_dims[i + 1];
}
in_idx += coords[D - 1];
d_in[in_idx] = d_out[out_idx];
}
}
template <>
class SliceGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>
: public framework::OpKernel<paddle::platform::float16> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* d_in = ctx.Output<framework::Tensor>(framework::GradVarName("Input"));
d_in->mutable_data<paddle::platform::float16>(ctx.GetPlace());
auto out_dims = d_out->dims();
auto in_dims = d_in->dims();
int rank = out_dims.size();
std::vector<int> offsets(rank, 0);
auto axes = ctx.Attr<std::vector<int>>("axes");
auto starts = ctx.Attr<std::vector<int>>("starts");
for (size_t i = 0; i < starts.size(); ++i) {
if (starts[i] < 0) {
starts[i] += in_dims[axes[i]];
}
offsets[axes[i]] = std::max(starts[i], 0);
}
math::SetConstant<paddle::platform::CUDADeviceContext,
paddle::platform::float16>
set_zero;
auto& dev_ctx =
ctx.template device_context<paddle::platform::CUDADeviceContext>();
set_zero(dev_ctx, d_in, static_cast<paddle::platform::float16>(0));
int64_t numel = d_out->numel();
dim3 blocks((numel - 1) / PADDLE_CUDA_NUM_THREADS + 1, 1, 1);
dim3 threads(PADDLE_CUDA_NUM_THREADS, 1, 1);
auto stream = ctx.cuda_device_context().stream();
auto out_shape = framework::vectorize2int(out_dims);
thrust::device_vector<int> out_dims_vec(out_shape.begin(), out_shape.end());
auto in_shape = framework::vectorize2int(in_dims);
thrust::device_vector<int> in_dims_vec(in_shape.begin(), in_shape.end());
thrust::device_vector<int> offsets_vec(offsets.begin(), offsets.end());
const int* out_dims_ptr = thrust::raw_pointer_cast(out_dims_vec.data());
const int* in_dims_ptr = thrust::raw_pointer_cast(in_dims_vec.data());
const int* offsets_ptr = thrust::raw_pointer_cast(offsets_vec.data());
switch (rank) {
case 1:
Padding<1><<<blocks, threads, 0, stream>>>(
d_out->data<paddle::platform::float16>(), out_dims_ptr, in_dims_ptr,
offsets_ptr, numel, d_in->data<paddle::platform::float16>());
break;
case 2:
Padding<2><<<blocks, threads, 0, stream>>>(
d_out->data<paddle::platform::float16>(), out_dims_ptr, in_dims_ptr,
offsets_ptr, numel, d_in->data<paddle::platform::float16>());
break;
case 3:
Padding<3><<<blocks, threads, 0, stream>>>(
d_out->data<paddle::platform::float16>(), out_dims_ptr, in_dims_ptr,
offsets_ptr, numel, d_in->data<paddle::platform::float16>());
break;
case 4:
Padding<4><<<blocks, threads, 0, stream>>>(
d_out->data<paddle::platform::float16>(), out_dims_ptr, in_dims_ptr,
offsets_ptr, numel, d_in->data<paddle::platform::float16>());
break;
case 5:
Padding<5><<<blocks, threads, 0, stream>>>(
d_out->data<paddle::platform::float16>(), out_dims_ptr, in_dims_ptr,
offsets_ptr, numel, d_in->data<paddle::platform::float16>());
break;
case 6:
Padding<6><<<blocks, threads, 0, stream>>>(
d_out->data<paddle::platform::float16>(), out_dims_ptr, in_dims_ptr,
offsets_ptr, numel, d_in->data<paddle::platform::float16>());
break;
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
slice, ops::SliceKernel<paddle::platform::CUDADeviceContext, float>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, double>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, int>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, int64_t>);
ops::SliceKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
slice_grad,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, plat::float16>);
......@@ -16,6 +16,7 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
......@@ -63,5 +64,28 @@ class TestCase2(TestSliceOp):
self.out = self.input[-3:3, 0:100, :, 2:-1]
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFP16(TestSliceOp):
def config(self):
self.dtype = "float16"
self.input = np.random.random([3, 4, 5, 6]).astype(self.dtype)
self.starts = [-3, 0, 2]
self.ends = [3, 100, -1]
self.axes = [0, 1, 3]
self.out = self.input[-3:3, 0:100, :, 2:-1]
def test_check_output(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-5)
def test_check_grad_normal(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_grad_with_place(
place, ['Input'], 'Out', max_relative_error=0.006)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册