From 79a41a9ed6f3e2acb28b310907b91011ab766eac Mon Sep 17 00:00:00 2001 From: QingshuChen Date: Mon, 14 Dec 2020 14:38:20 +0800 Subject: [PATCH] support roi_align & affine_channel for kunlun (#29561) * support roi_align & affine_channel for kunlun * minor --- cmake/external/xpu.cmake | 2 +- .../fluid/operators/affine_channel_op_xpu.cc | 186 +++++++++++++++ paddle/fluid/operators/roi_align_op_xpu.cc | 211 ++++++++++++++---- .../xpu/test_affine_channel_op_xpu.py | 148 ++++++++++++ .../unittests/xpu/test_roi_align_op_xpu.py | 20 +- 5 files changed, 510 insertions(+), 57 deletions(-) create mode 100644 paddle/fluid/operators/affine_channel_op_xpu.cc create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_affine_channel_op_xpu.py diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index c9cf2572d1..75e0eb2e27 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -4,7 +4,7 @@ endif() INCLUDE(ExternalProject) SET(XPU_PROJECT "extern_xpu") -SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2020_12_07_cdfbf0c.tar.gz" CACHE STRING "" FORCE) +SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2020_12_11.tar.gz" CACHE STRING "" FORCE) SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu") SET(XPU_DOWNLOAD_DIR "${XPU_SOURCE_DIR}/src/${XPU_PROJECT}") SET(XPU_INSTALL_DIR "${THIRD_PARTY_PATH}/install/xpu") diff --git a/paddle/fluid/operators/affine_channel_op_xpu.cc b/paddle/fluid/operators/affine_channel_op_xpu.cc new file mode 100644 index 0000000000..db3eedea7c --- /dev/null +++ b/paddle/fluid/operators/affine_channel_op_xpu.cc @@ -0,0 +1,186 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +Indicesou may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_XPU + +#include +#include +#include +#include "paddle/fluid/framework/data_layout.h" +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class AffineChannelXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* scale = ctx.Input("Scale"); + auto* bias = ctx.Input("Bias"); + + auto* y = ctx.Output("Out"); + y->mutable_data(ctx.GetPlace()); + + const framework::DataLayout layout = + framework::StringToDataLayout(ctx.Attr("data_layout")); + + auto dims = x->dims(); + int N = dims[0]; + int C = layout == framework::DataLayout::kNCHW ? dims[1] + : dims[dims.size() - 1]; + int HxW = x->numel() / N / C; + + auto* scale_d = scale->data(); + auto* bias_d = bias->data(); + + auto* x_d = x->data(); + auto* y_d = y->data(); + auto& dev_ctx = ctx.template device_context(); + std::vector x_shape; + std::vector b_shape; + if (layout == framework::DataLayout::kNCHW) { + x_shape.push_back(N); + x_shape.push_back(C); + x_shape.push_back(HxW); + b_shape.push_back(1); + b_shape.push_back(C); + b_shape.push_back(1); + } else { + x_shape.push_back(N * HxW); + x_shape.push_back(C); + b_shape.push_back(1); + b_shape.push_back(C); + } + int r = 0; + r = xpu::broadcast_mul(dev_ctx.x_context(), x_d, scale_d, y_d, x_shape, + b_shape); + PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, + platform::errors::External( + "The broadcast_mul XPU OP return wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + r = xpu::broadcast_add(dev_ctx.x_context(), y_d, bias_d, y_d, x_shape, + b_shape); + PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, + platform::errors::External( + "The broadcast_add XPU OP return wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + } +}; + +template +class AffineChannelGradXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* scale = ctx.Input("Scale"); + auto* dy = ctx.Input(framework::GradVarName("Out")); + + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dscale = + ctx.Output(framework::GradVarName("Scale")); + auto* dbias = ctx.Output(framework::GradVarName("Bias")); + + const framework::DataLayout layout = + framework::StringToDataLayout(ctx.Attr("data_layout")); + + auto dims = x->dims(); + int N = dims[0]; + int C = layout == framework::DataLayout::kNCHW ? dims[1] + : dims[dims.size() - 1]; + int HxW = x->numel() / N / C; + + auto* dy_d = dy->data(); + auto* scale_d = scale->data(); + + T* dx_d = dx ? dx->mutable_data(ctx.GetPlace()) : nullptr; + T* dscale_d = dscale ? dscale->mutable_data(ctx.GetPlace()) : nullptr; + T* dbias_d = dbias ? dbias->mutable_data(ctx.GetPlace()) : nullptr; + + auto& dev_ctx = ctx.template device_context(); + std::vector x_shape; + std::vector b_shape; + std::vector rdims; + if (layout == framework::DataLayout::kNCHW) { + x_shape.push_back(N); + x_shape.push_back(C); + x_shape.push_back(HxW); + b_shape.push_back(1); + b_shape.push_back(C); + b_shape.push_back(1); + rdims.push_back(0); + rdims.push_back(2); + } else { + x_shape.push_back(N * HxW); + x_shape.push_back(C); + b_shape.push_back(1); + b_shape.push_back(C); + rdims.push_back(0); + } + + int r = 0; + if (dscale_d && dbias_d) { + r = xpu::reduce_sum(dev_ctx.x_context(), dy_d, dbias_d, x_shape, + rdims); + PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, + platform::errors::External( + "The reduce_sum XPU OP return wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + T* tmp = nullptr; + r = xpu_malloc(reinterpret_cast(&tmp), dy->numel() * sizeof(T)); + PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, + platform::errors::External("no enough memory in xpu")); + + r = xpu::mul(dev_ctx.x_context(), dy_d, x->data(), tmp, + dy->numel()); + PADDLE_ENFORCE_EQ( + r, xpu::Error_t::SUCCESS, + platform::errors::External("The mul XPU OP return wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + r = xpu::reduce_sum(dev_ctx.x_context(), tmp, dscale_d, x_shape, + rdims); + PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, + platform::errors::External( + "The reduce_sum XPU OP return wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + if (dev_ctx.x_context()->xpu_stream) { + dev_ctx.Wait(); + } + xpu_free(tmp); + } + if (dx_d) { + r = xpu::broadcast_mul(dev_ctx.x_context(), dy_d, scale_d, dx_d, x_shape, + b_shape); + PADDLE_ENFORCE_EQ( + r, xpu::Error_t::SUCCESS, + platform::errors::External( + "The broadcast_mul XPU OP return wrong value[%d %s]", r, + XPUAPIErrorMsg[r])); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +using XPU = paddle::platform::XPUDeviceContext; + +REGISTER_OP_XPU_KERNEL(affine_channel, ops::AffineChannelXPUKernel); +REGISTER_OP_XPU_KERNEL(affine_channel_grad, + ops::AffineChannelGradXPUKernel); + +#endif diff --git a/paddle/fluid/operators/roi_align_op_xpu.cc b/paddle/fluid/operators/roi_align_op_xpu.cc index 699cc7b84a..f35cf06e5f 100644 --- a/paddle/fluid/operators/roi_align_op_xpu.cc +++ b/paddle/fluid/operators/roi_align_op_xpu.cc @@ -24,89 +24,202 @@ template class XPUROIAlignOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* in = ctx.Input("X"); - auto* rois = ctx.Input("ROIs"); - auto* out = ctx.Output("Out"); + auto* in = ctx.Input("X"); + auto* rois = ctx.Input("ROIs"); + auto* out = ctx.Output("Out"); + auto pooled_height = ctx.Attr("pooled_height"); auto pooled_width = ctx.Attr("pooled_width"); auto spatial_scale = ctx.Attr("spatial_scale"); auto sampling_ratio = ctx.Attr("sampling_ratio"); - auto& dev_ctx = ctx.template device_context(); + auto in_dims = in->dims(); int batch_size = in_dims[0]; int channels = in_dims[1]; int height = in_dims[2]; int width = in_dims[3]; + int rois_num = rois->dims()[0]; - const T* input_data = in->data(); - framework::Tensor _roi_batch_list; - _roi_batch_list.Resize({rois_num}); - int* rois_lod = _roi_batch_list.mutable_data(ctx.GetPlace()); - int rois_batch_size = 1; + if (rois_num == 0) return; + + Tensor roi_batch_id_list; + roi_batch_id_list.Resize({rois_num}); + auto cplace = platform::CPUPlace(); + int* roi_batch_id_data = roi_batch_id_list.mutable_data(cplace); + auto& dev_ctx = ctx.template device_context(); + auto xplace = BOOST_GET_CONST(platform::XPUPlace, ctx.GetPlace()); + int rois_batch_size = 0; + int* cpu_lod = nullptr; if (ctx.HasInput("RoisNum")) { - auto* rois_num_t = ctx.Input("RoisNum"); + auto* rois_num_t = ctx.Input("RoisNum"); rois_batch_size = rois_num_t->numel(); PADDLE_ENFORCE_EQ( rois_batch_size, batch_size, platform::errors::InvalidArgument( - "The batch size of rois and the batch size of images " - " must be the same. But received the batch size of rois is %d, " - "and the batch size of images is %d", + "The rois_batch_size and imgs " + "batch_size must be the same. But received rois_batch_size = %d, " + "batch_size = %d", rois_batch_size, batch_size)); - auto* rois_num_data = rois_num_t->data(); - rois_lod[0] = 0; - for (int n = 0; n < rois_batch_size; ++n) { - rois_lod[n + 1] = rois_lod[n] + rois_num_data[n]; + + std::vector rois_num_list(rois_batch_size); + memory::Copy(cplace, rois_num_list.data(), xplace, + rois_num_t->data(), sizeof(int) * rois_batch_size); + cpu_lod = new int[rois_batch_size + 1]; + cpu_lod[0] = 0; + for (int i = 0; i < rois_batch_size; i++) { + cpu_lod[i + 1] = cpu_lod[i] + rois_num_list[i]; } } else { - auto _rois_lod = rois->lod().back(); - rois_batch_size = _rois_lod.size() - 1; - for (int n = 0; n < static_cast(_rois_lod.size()); ++n) { - rois_lod[n] = _rois_lod[n]; - } + auto lod = rois->lod(); + PADDLE_ENFORCE_EQ( + lod.empty(), false, + platform::errors::InvalidArgument("Input(ROIs) in ROIAlignOp does " + "not contain LoD information.")); + auto rois_lod = lod.back(); + rois_batch_size = rois_lod.size() - 1; PADDLE_ENFORCE_EQ( rois_batch_size, batch_size, platform::errors::InvalidArgument( - "The rois_batch_size and imgs batch_size of roi_align_xpu OP " - "must " - "be the same. But received rois_batch_size %d , batch_size %d", + "The batch size of rois and batch size " + "of images must be the same. But received rois batch size = %d, " + "and images batch size = %d", rois_batch_size, batch_size)); + int rois_num_with_lod = rois_lod[rois_batch_size]; + PADDLE_ENFORCE_EQ( + rois_num, rois_num_with_lod, + platform::errors::InvalidArgument( + "The actual number of rois and the number of rois " + "provided from Input(RoIsLoD) in RoIAlign must be the same." + " But received actual number of rois is %d, and the number " + "of rois from RoIsLoD is %d", + rois_num, rois_num_with_lod)); + for (int n = 0; n < rois_batch_size; ++n) { + for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + roi_batch_id_data[i] = n; + } + } + cpu_lod = new int[rois_batch_size + 1]; + for (int i = 0; i < rois_batch_size + 1; i++) { + cpu_lod[i] = rois_lod[i]; + } } - int rois_num_with_lod = rois_lod[rois_batch_size]; - PADDLE_ENFORCE_EQ( - rois_num, rois_num_with_lod, - platform::errors::InvalidArgument( - "The rois_num from input and lod of roi_align_xpu OP must be the " - "same. But received input rois_num %d , input lod %d", - rois_num, rois_num_with_lod)); - T* output_data = out->mutable_data(ctx.GetPlace()); - const T* rois_data = rois->data(); - for (int n = 0; n < rois_batch_size; n++) { - int cur_batch_rois_num = rois_lod[n + 1] - rois_lod[n]; - if (cur_batch_rois_num != 0) { - int r = xpu::roi_align( - dev_ctx.x_context(), input_data + n * channels * height * width, - rois_data + rois_lod[n] * 4, cur_batch_rois_num, channels, height, - width, pooled_height, pooled_width, sampling_ratio, spatial_scale, - output_data + - rois_lod[n] * channels * pooled_height * pooled_width); - PADDLE_ENFORCE_EQ( - r, xpu::Error_t::SUCCESS, - platform::errors::External( - "The roi_align XPU OP return wrong value[%d], please check " - "where Baidu Kunlun Card is properly installed.", - r)); + + int* roi_id_data = nullptr; + int r = xpu_malloc(reinterpret_cast(&roi_id_data), + (rois_batch_size + 1) * sizeof(int)); + PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, + platform::errors::External("no enough memory in xpu")); + memory::Copy(xplace, roi_id_data, cplace, cpu_lod, + (rois_batch_size + 1) * sizeof(int)); + delete[] cpu_lod; + r = xpu::roi_align( + dev_ctx.x_context(), in->data(), + out->mutable_data(ctx.GetPlace()), rois->data(), roi_id_data, + batch_size, channels, height, width, out->dims()[0], pooled_height, + pooled_width, spatial_scale, sampling_ratio, true); + PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, + platform::errors::External( + "The roi_align XPU OP return wrong value[%d %s]", r, + XPUAPIErrorMsg[r])); + if (dev_ctx.x_context()->xpu_stream) { + dev_ctx.Wait(); + } + xpu_free(roi_id_data); + } +}; + +template +class XPUROIAlignGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* rois = ctx.Input("ROIs"); + + auto* out_grad = ctx.Input(framework::GradVarName("Out")); + auto* in_grad = ctx.Output(framework::GradVarName("X")); + + auto pooled_height = ctx.Attr("pooled_height"); + auto pooled_width = ctx.Attr("pooled_width"); + auto spatial_scale = ctx.Attr("spatial_scale"); + auto sampling_ratio = ctx.Attr("sampling_ratio"); + + int rois_num = rois->dims()[0]; + int channels = in->dims()[1]; + int height = in->dims()[2]; + int width = in->dims()[3]; + + if (!in_grad) { + return; + } + Tensor roi_batch_id_list; + roi_batch_id_list.Resize({rois_num}); + auto cplace = platform::CPUPlace(); + + auto& dev_ctx = ctx.template device_context(); + auto xplace = BOOST_GET_CONST(platform::XPUPlace, ctx.GetPlace()); + + int rois_batch_size = 0; + int* cpu_lod = nullptr; + if (ctx.HasInput("RoisNum")) { + auto* rois_num_t = ctx.Input("RoisNum"); + rois_batch_size = rois_num_t->numel(); + std::vector rois_num_list(rois_batch_size); + memory::Copy(cplace, rois_num_list.data(), xplace, + rois_num_t->data(), sizeof(int) * rois_batch_size); + cpu_lod = new int[rois_batch_size + 1]; + cpu_lod[0] = 0; + for (int i = 0; i < rois_batch_size; i++) { + cpu_lod[i + 1] = cpu_lod[i] + rois_num_list[i]; + } + } else { + auto rois_lod = rois->lod().back(); + rois_batch_size = rois_lod.size() - 1; + cpu_lod = new int[rois_batch_size + 1]; + for (int i = 0; i < rois_batch_size + 1; i++) { + cpu_lod[i] = rois_lod[i]; } } + int* roi_id_data = nullptr; + int r = xpu_malloc(reinterpret_cast(&roi_id_data), + (rois_batch_size + 1) * sizeof(int)); + PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, + platform::errors::External("no enough memory in xpu")); + memory::Copy(xplace, roi_id_data, cplace, cpu_lod, + (rois_batch_size + 1) * sizeof(int)); + in_grad->mutable_data(ctx.GetPlace()); + + int output_grad_size = out_grad->numel(); + + delete[] cpu_lod; + if (output_grad_size > 0) { + r = xpu::roi_align_grad( + dev_ctx.x_context(), out_grad->data(), in_grad->data(), + rois->data(), roi_id_data, in->dims()[0], channels, height, width, + out_grad->dims()[0], pooled_height, pooled_width, spatial_scale, + sampling_ratio, true); + PADDLE_ENFORCE_EQ( + r, xpu::Error_t::SUCCESS, + platform::errors::External( + "The roi_align_grad XPU OP return wrong value[%d %s]", r, + XPUAPIErrorMsg[r])); + } + if (dev_ctx.x_context()->xpu_stream) { + dev_ctx.Wait(); + } + xpu_free(roi_id_data); } }; } // namespace operators } // namespace paddle + namespace ops = paddle::operators; REGISTER_OP_XPU_KERNEL( roi_align, ops::XPUROIAlignOpKernel); +REGISTER_OP_XPU_KERNEL( + roi_align_grad, + ops::XPUROIAlignGradOpKernel); #endif diff --git a/python/paddle/fluid/tests/unittests/xpu/test_affine_channel_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_affine_channel_op_xpu.py new file mode 100644 index 0000000000..3385d671d7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_affine_channel_op_xpu.py @@ -0,0 +1,148 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Unit testing for affine_channel_op +""" + +from __future__ import print_function + +import sys +sys.path.append("..") + +import unittest +import numpy as np +from op_test_xpu import XPUOpTest +import paddle +import paddle.fluid.core as core +import paddle.fluid as fluid + + +def affine_channel(x, scale, bias, layout): + C = x.shape[1] if layout == 'NCHW' else x.shape[-1] + if len(x.shape) == 4: + new_shape = (1, C, 1, 1) if layout == 'NCHW' else (1, 1, 1, C) + else: + new_shape = (1, C) + scale = scale.reshape(new_shape) + bias = bias.reshape(new_shape) + return x * scale + bias + + +class TestAffineChannelOp(XPUOpTest): + def setUp(self): + self.op_type = "affine_channel" + self.init_test_case() + + x = np.random.random(self.shape).astype("float32") + scale = np.random.random(self.C).astype("float32") + bias = np.random.random(self.C).astype("float32") + + y = affine_channel(x, scale, bias, self.layout) + + self.inputs = {'X': x, 'Scale': scale, 'Bias': bias} + self.attrs = {'data_layout': self.layout} + self.outputs = {'Out': y} + + def test_check_output(self): + if core.is_compiled_with_xpu(): + paddle.enable_static() + place = paddle.XPUPlace(0) + self.check_output_with_place(place) + + def test_check_grad(self): + if core.is_compiled_with_xpu(): + paddle.enable_static() + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['X', 'Scale', 'Bias'], 'Out') + + def test_check_grad_stopgrad_dx(self): + if core.is_compiled_with_xpu(): + paddle.enable_static() + place = paddle.XPUPlace(0) + self.check_grad_with_place( + place, ['Scale', 'Bias'], 'Out', no_grad_set=set('X')) + + def test_check_grad_stopgrad_dscale_dbias(self): + if core.is_compiled_with_xpu(): + paddle.enable_static() + place = paddle.XPUPlace(0) + self.check_grad_with_place( + place, ['X'], 'Out', no_grad_set=set(['Scale', 'Bias'])) + + def init_test_case(self): + self.shape = [2, 100, 3, 3] + self.C = 100 + self.layout = 'NCHW' + + +class TestAffineChannelOpError(unittest.TestCase): + def test_errors(self): + with fluid.program_guard(fluid.Program()): + + def test_x_type(): + input_data = np.random.random(2, 1, 2, 2).astype("float32") + fluid.layers.affine_channel(input_data) + + self.assertRaises(TypeError, test_x_type) + + def test_x_dtype(): + x2 = fluid.layers.data( + name='x2', shape=[None, 1, 2, 2], dtype='int32') + fluid.layers.affine_channel(x2) + + self.assertRaises(TypeError, test_x_dtype) + + def test_scale_type(): + x3 = fluid.layers.data( + name='x3', shape=[None, 1, 2, 2], dtype='float32') + fluid.layers.affine_channel(x3, scale=1) + + self.assertRaises(TypeError, test_scale_type) + + def test_bias_type(): + x4 = fluid.layers.data( + name='x4', shape=[None, 1, 2, 2], dtype='float32') + fluid.layers.affine_channel(x4, bias=1) + + self.assertRaises(TypeError, test_bias_type) + + +class TestAffineChannelNHWC(TestAffineChannelOp): + def init_test_case(self): + self.shape = [2, 3, 3, 100] + self.C = 100 + self.layout = 'NHWC' + + def test_check_grad_stopgrad_dx(self): + return + + def test_check_grad_stopgrad_dscale_dbias(self): + return + + +class TestAffineChannel2D(TestAffineChannelOp): + def init_test_case(self): + self.shape = [2, 100] + self.C = 100 + self.layout = 'NCHW' + + def test_check_grad_stopgrad_dx(self): + return + + def test_check_grad_stopgrad_dscale_dbias(self): + return + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_roi_align_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_roi_align_op_xpu.py index 70f03edb6b..2122223dbe 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_roi_align_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_roi_align_op_xpu.py @@ -20,13 +20,13 @@ import math import numpy as np import paddle.fluid.core as core from op_test import OpTest, skip_check_grad_ci +from op_test_xpu import XPUOpTest import paddle import paddle.fluid as fluid from paddle.fluid import Program, program_guard -@skip_check_grad_ci(reason="There is no grad kernel for roi_align_xpu kernel.") -class TestROIAlignOp(OpTest): +class TestROIAlignOp(XPUOpTest): def set_data(self): self.init_test_case() self.make_rois() @@ -59,16 +59,16 @@ class TestROIAlignOp(OpTest): self.pooled_width = 2 self.sampling_ratio = -1 - self.x = np.random.random(self.x_dim).astype('float64') + self.x = np.random.random(self.x_dim).astype('float32') def pre_calc(self, x_i, roi_xmin, roi_ymin, roi_bin_grid_h, roi_bin_grid_w, bin_size_h, bin_size_w): count = roi_bin_grid_h * roi_bin_grid_w bilinear_pos = np.zeros( [self.channels, self.pooled_height, self.pooled_width, count, 4], - np.float64) + np.float32) bilinear_w = np.zeros( - [self.pooled_height, self.pooled_width, count, 4], np.float64) + [self.pooled_height, self.pooled_width, count, 4], np.float32) for ph in range(self.pooled_width): for pw in range(self.pooled_height): c = 0 @@ -118,7 +118,7 @@ class TestROIAlignOp(OpTest): def calc_roi_align(self): self.out_data = np.zeros( (self.rois_num, self.channels, self.pooled_height, - self.pooled_width)).astype('float64') + self.pooled_width)).astype('float32') for i in range(self.rois_num): roi = self.rois[i] @@ -166,7 +166,7 @@ class TestROIAlignOp(OpTest): roi = [bno, x1, y1, x2, y2] rois.append(roi) self.rois_num = len(rois) - self.rois = np.array(rois).astype("float64") + self.rois = np.array(rois).astype("float32") def setUp(self): self.op_type = "roi_align" @@ -178,6 +178,12 @@ class TestROIAlignOp(OpTest): place = paddle.XPUPlace(0) self.check_output_with_place(place) + def test_check_grad(self): + if core.is_compiled_with_xpu(): + paddle.enable_static() + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, {'X'}, 'Out') + class TestROIAlignInLodOp(TestROIAlignOp): def set_data(self): -- GitLab