From 6bfc57215fbe0f9f876dfce45ca67ec10c4f7e2b Mon Sep 17 00:00:00 2001 From: liuyuhui Date: Tue, 8 Dec 2020 20:29:54 +0800 Subject: [PATCH] [2.0 rc1/cherrypick] cherry-pick kunlun PR:29234/29229/29293/29367/29280/29448 (#29466) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add deformable_conv op on xpu (#29234) * rebase develop * update deformable_conv op on xpu * update deformable_conv op on xpu * update kunlun conv2d/softmax/elementwise implemetation (#29229) * update conv2d & softmax to new xpu api * test=kunlun * remove useless comments * test=kunlun * remote softmax xpu op * test=kunlun * update kunlun softmax * test=kunlun * update xpu unitest * test=kunlun * fix elementwise_grad bug for kunlun *test=kunlun * support global pooling for kunlun (#29293) * test=kunlun * update reduce_sum op on xpu (#29367) * update reduce_sum op on xpu * update reduce_sum op on xpu * support running on xpu * fix expand/uniform_random && concat/transpose to new api on xpu (#29280) * fix expand && concat/transpose to new api * update uniform_random_op * update xpu_header * 1. fix elementwise ops'bug 2. fix softmax_with_cross_entropy_op 3. add biliner_interp_op (#29448) Co-authored-by: root Co-authored-by: 卖鱼的哲学 Co-authored-by: QingshuChen Co-authored-by: taixiurong Co-authored-by: root --- cmake/external/xpu.cmake | 2 +- paddle/fluid/operators/concat_op_xpu.cc | 127 +++-- paddle/fluid/operators/conv_op_xpu.cc | 125 +---- .../fluid/operators/deformable_conv_op_xpu.cc | 286 ++++++++++ .../elementwise/elementwise_div_op_xpu.cc | 3 +- .../elementwise/elementwise_max_op_xpu.cc | 3 +- .../elementwise/elementwise_min_op_xpu.cc | 3 +- .../elementwise/elementwise_mul_op_xpu.cc | 3 +- .../operators/elementwise/elementwise_xpu.h | 52 +- paddle/fluid/operators/expand_op.h | 16 +- paddle/fluid/operators/interpolate_op_xpu.cc | 258 +++++++++ paddle/fluid/operators/pool_op_xpu.cc | 24 +- .../operators/reduce_ops/reduce_sum_op_xpu.cc | 162 +++--- paddle/fluid/operators/softmax_op_xpu.cc | 51 +- .../softmax_with_cross_entropy_op_xpu.cc | 3 +- paddle/fluid/operators/transpose_op_xpu.cc | 100 +--- .../fluid/operators/uniform_random_op_xpu.cc | 51 +- paddle/fluid/platform/xpu_header.h | 9 + python/paddle/fluid/io.py | 4 + .../fluid/tests/unittests/op_test_xpu.py | 11 - .../xpu/test_bilinear_interp_op_xpu.py | 519 ++++++++++++++++++ .../tests/unittests/xpu/test_concat_op_xpu.py | 98 +--- .../xpu/test_deformable_conv_op_xpu.py | 274 +++++++++ .../unittests/xpu/test_reduce_sum_op_xpu.py | 223 +++----- .../test_softmax_with_cross_entropy_op_xpu.py | 9 +- .../unittests/xpu/test_transpose_op_xpu.py | 116 +--- 26 files changed, 1793 insertions(+), 739 deletions(-) create mode 100644 paddle/fluid/operators/deformable_conv_op_xpu.cc create mode 100644 paddle/fluid/operators/interpolate_op_xpu.cc create mode 100755 python/paddle/fluid/tests/unittests/xpu/test_bilinear_interp_op_xpu.py create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_deformable_conv_op_xpu.py diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index 8d3fee915c4..a709616314b 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -4,7 +4,7 @@ endif() INCLUDE(ExternalProject) SET(XPU_PROJECT "extern_xpu") -SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2020_11_10.tar.gz" CACHE STRING "" FORCE) +SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2020_12_04.tar.gz" CACHE STRING "" FORCE) SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu") SET(XPU_DOWNLOAD_DIR "${XPU_SOURCE_DIR}/src/${XPU_PROJECT}") SET(XPU_INSTALL_DIR "${THIRD_PARTY_PATH}/install/xpu") diff --git a/paddle/fluid/operators/concat_op_xpu.cc b/paddle/fluid/operators/concat_op_xpu.cc index 9c9c72c7f6f..0558f09a174 100644 --- a/paddle/fluid/operators/concat_op_xpu.cc +++ b/paddle/fluid/operators/concat_op_xpu.cc @@ -11,18 +11,12 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - +#ifdef PADDLE_WITH_XPU #include "paddle/fluid/operators/concat_op.h" - #include #include #include - -#ifdef PADDLE_WITH_MKLDNN -#include -#endif - -#ifdef PADDLE_WITH_XPU +#include "paddle/fluid/platform/xpu_header.h" namespace paddle { namespace operators { @@ -32,8 +26,8 @@ template class ConcatXPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto ins = ctx.MultiInput("X"); - framework::Tensor* out = ctx.Output("Out"); + auto ins = ctx.MultiInput("X"); + framework::LoDTensor* out = ctx.Output("Out"); int axis = ctx.Attr("axis"); PADDLE_ENFORCE_NE(ins[0], nullptr, platform::errors::InvalidArgument( "The input should not be null.")); @@ -47,6 +41,7 @@ class ConcatXPUKernel : public framework::OpKernel { PADDLE_ENFORCE_LT(axis, ins[0]->dims().size(), platform::errors::InvalidArgument( "concat: axis shoud < ins[0]->dims()!")); + auto place = ctx.GetPlace(); out->mutable_data(place); std::vector choose_idx; @@ -57,43 +52,54 @@ class ConcatXPUKernel : public framework::OpKernel { n++; } } - PADDLE_ENFORCE_LE(n, 8, platform::errors::InvalidArgument( - "XPU only surpport at most 8 tensors for now")); PADDLE_ENFORCE_GT( n, 0, platform::errors::InvalidArgument("No tensor need concat?")); - int h = 1; - int w_except_axis = 1; - for (int i = 0; i < axis; ++i) { - h *= (ins[choose_idx[0]]->dims())[i]; - } - for (int i = axis + 1; i < ins[0]->dims().size(); ++i) { - w_except_axis *= (ins[choose_idx[0]]->dims())[i]; - } - for (int i = 1; i < n; ++i) { - int hh = 1; - int ww = 1; - for (int j = 0; j < axis; ++j) { - hh *= (ins[choose_idx[i]]->dims())[j]; + + // If axis is 0, the lod of the output is not the same as inputs. + if (axis == 0 && ins[0]->lod().size() > 0) { + size_t lod_size_0 = ins[0]->lod().size(); + size_t lod_size = lod_size_0; + for (size_t i = 1; i < ins.size(); ++i) { + if (ins[i]->lod().size() > 0) { + PADDLE_ENFORCE_EQ( + ins[i]->lod().size(), lod_size_0, + platform::errors::Unimplemented( + "The lod level of all input LoDTensors should be same. " + "Maybe different lod level of input LoDTensors can concat," + "it is not supported currently. The lod level of %dth input " + "is %d and first input is %d.", + i, ins[i]->lod().size(), lod_size_0)); + } else { + lod_size = 0; + break; + } } - for (int j = axis + 1; j < ins[i]->dims().size(); ++j) { - ww *= (ins[choose_idx[i]]->dims())[j]; + if (lod_size) { + auto* out_lod = out->mutable_lod(); + for (size_t i = 1; i < ins.size(); ++i) { + auto in_lod = ConvertToLengthBasedLoD(ins[i]->lod()); + AppendLoD(out_lod, in_lod); + } } - PADDLE_ENFORCE_EQ(hh, h, platform::errors::InvalidArgument( - "concat: h should be eual!")); - PADDLE_ENFORCE_EQ(ww, w_except_axis, - platform::errors::InvalidArgument( - "concat: w should be eual except for axis!")); } + + auto input_dims = ins[0]->dims(); + std::vector> xdims_list(n); + for (int i = 0; i < n; ++i) { + std::vector tmp_dims(input_dims.size()); + for (int j = 0; j < input_dims.size(); ++j) { + tmp_dims[j] = ins[i]->dims()[j]; + } + xdims_list[i] = tmp_dims; + } + auto& dev_ctx = ctx.template device_context(); - std::unique_ptr in_w_host(new int[n]); - std::unique_ptr ptrs(new const float*[n]); + std::vector ptrs; for (int i = 0; i < n; ++i) { - ptrs[i] = ins[choose_idx[i]]->data(); - in_w_host[i] = w_except_axis * (ins[choose_idx[i]]->dims())[axis]; + ptrs.push_back(ins[choose_idx[i]]->data()); } - int r = - xpu::concat(dev_ctx.x_context(), h, (const int*)in_w_host.get(), - n, (const float**)ptrs.get(), out->data()); + int r = xpu::concat(dev_ctx.x_context(), ptrs, out->data(), + xdims_list, axis); PADDLE_ENFORCE_EQ( r, XPU_SUCCESS, platform::errors::External( @@ -102,6 +108,7 @@ class ConcatXPUKernel : public framework::OpKernel { r)); } }; + template class ConcatGradXPUKernel : public framework::OpKernel { public: @@ -132,13 +139,15 @@ class ConcatGradXPUKernel : public framework::OpKernel { static_cast(ins[0]->dims().size())); // get output tensor that the name is not kEmptyVarName std::vector outputs; + std::vector choose_idx; + int n = 0; for (size_t j = 0; j < outs.size(); ++j) { if (out_var_names[j] != framework::kEmptyVarName && outs[j]->numel() != 0UL) { outs[j]->mutable_data(ctx.GetPlace()); outputs.push_back(outs[j]); - } else { - outputs.push_back(nullptr); + choose_idx.push_back(j); + n++; } } PADDLE_ENFORCE_GE(axis, 0, platform::errors::InvalidArgument( @@ -146,23 +155,31 @@ class ConcatGradXPUKernel : public framework::OpKernel { PADDLE_ENFORCE_LT(axis, out_grad->dims().size(), platform::errors::InvalidArgument( "concat_grad: axis shoud < ins[0]->dims()!")); - auto out_grad_stride = framework::stride_numel(out_grad->dims()); - int n = outputs.size(); - PADDLE_ENFORCE_LE(n, 16, - platform::errors::InvalidArgument( - "XPU only surpport at most 16 tensors for now")); - int h = out_grad_stride[0] / out_grad_stride[axis]; - auto& dev_ctx = ctx.template device_context(); - std::unique_ptr in_w_host(new int[n]); - std::unique_ptr ptrs(new float*[n]); + + auto input_dims = ins[0]->dims(); + std::vector split_list(n); + std::vector xdims_list(input_dims.size()); + int total_length = 0; + for (int i = 0; i < n; ++i) { + split_list[i] = ins[i]->dims()[axis]; + total_length += ins[i]->dims()[axis]; + } + for (int i = 0; i < input_dims.size(); ++i) { + if (i == axis) { + continue; + } + xdims_list[i] = input_dims[i]; + } + xdims_list[axis] = total_length; + + std::vector ptrs(n); for (int i = 0; i < n; ++i) { - auto out_stride = framework::stride_numel(outputs[i]->dims()); ptrs[i] = outputs[i]->data(); - in_w_host[i] = out_stride[axis]; } - int r = xpu::concat_grad(dev_ctx.x_context(), h, in_w_host.get(), n, - reinterpret_cast(ptrs.get()), - out_grad->data()); + + auto& dev_ctx = ctx.template device_context(); + int r = xpu::split(dev_ctx.x_context(), out_grad->data(), ptrs, + xdims_list, split_list, axis); PADDLE_ENFORCE_EQ( r, XPU_SUCCESS, platform::errors::External( diff --git a/paddle/fluid/operators/conv_op_xpu.cc b/paddle/fluid/operators/conv_op_xpu.cc index 65ed34e8a5e..46af4d30500 100644 --- a/paddle/fluid/operators/conv_op_xpu.cc +++ b/paddle/fluid/operators/conv_op_xpu.cc @@ -27,10 +27,6 @@ class GemmConvXPUKernel : public framework::OpKernel { // that avoids modifying the variable in the Scope. Tensor filter = *context.Input("Filter"); Tensor* output = context.Output("Output"); - // Tensor* max_input = context.Output("MaxInput"); - // Tensor* max_filter = context.Output("MaxFilter"); - // max_input->mutable_data(context.GetPlace()); - // max_filter->mutable_data(context.GetPlace()); output->mutable_data(context.GetPlace()); int groups = context.Attr("groups"); std::vector strides = context.Attr>("strides"); @@ -43,52 +39,18 @@ class GemmConvXPUKernel : public framework::OpKernel { const int f = static_cast(filter.dims()[0]); const int win_h = static_cast(filter.dims()[2]); const int win_w = static_cast(filter.dims()[3]); - PADDLE_ENFORCE_EQ( - dilations[0] == 1 && dilations[1] == 1, true, - platform::errors::InvalidArgument("XPU only support dilation == 1.")); auto& dev_ctx = context.template device_context(); - // PADDLE_ENFORCE_EQ( - // xpu::findmax(dev_ctx.x_context(), input->data(), input->numel(), - // max_input->data()) == xpu::Error_t::SUCCESS, - // true, platform::errors::InvalidArgument( - // "XPU conv kernel error,can not finde max_input,please " - // "check whether Baidu Kunlun " - // "Card is properly installed.")); - // PADDLE_ENFORCE_EQ( - // xpu::findmax(dev_ctx.x_context(), filter.data(), filter.numel(), - // max_filter->data()) == xpu::Error_t::SUCCESS, - // true, platform::errors::InvalidArgument( - // "XPU conv kernel error,can not find max_filter,please " - // "check whether Baidu Kunlun " - // "Card is properly installed.")); - if (groups == 1) { - int r = xpu::conv2d_forward_int16( - dev_ctx.x_context(), batch_size, img_c, img_h, img_w, f, win_h, win_w, - strides[0], strides[1], paddings[0], paddings[1], dilations[0], - dilations[1], groups, input->data(), filter.data(), - output->data(), nullptr, nullptr, xpu::Activation_t::LINEAR, - nullptr, nullptr); - // max_input->data(), max_filter->data()); - PADDLE_ENFORCE_EQ( - r, XPU_SUCCESS, - platform::errors::External("XPU conv kernel return wrong value[%d], " - "please check whether Baidu Kunlun Card " - "is properly installed.", - r)); - } else { - int r = xpu::conv2d_int16_with_group( - dev_ctx.x_context(), input->data(), filter.data(), - output->data(), batch_size, img_c, img_h, img_w, f, win_h, - win_w, groups, strides[0], strides[1], paddings[0], paddings[1], - nullptr, nullptr); - // max_input->data(), max_filter->data()); - PADDLE_ENFORCE_EQ( - r, XPU_SUCCESS, - platform::errors::External("XPU conv kernel return wrong value[%d], " - "please check whether Baidu Kunlun Card " - "is properly installed.", - r)); - } + std::vector k_size; + k_size.push_back(win_h); + k_size.push_back(win_w); + int r = xpu::conv2d( + dev_ctx.x_context(), input->data(), filter.data(), + output->data(), batch_size, img_c, img_h, img_w, f, k_size, + strides, paddings, dilations, groups, nullptr, nullptr, nullptr, true); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External("XPU conv kernel return wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); } }; template @@ -96,9 +58,6 @@ class GemmConvGradXPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { const Tensor* input = context.Input("Input"); - // const Tensor* max_input = context.Input("MaxInput"); - // const Tensor* max_filter = context.Input("MaxFilter"); - // Tensor* max_output_grad = context.Output("MaxOutputGrad"); const Tensor* output_grad = context.Input(framework::GradVarName("Output")); Tensor* input_grad = @@ -115,11 +74,6 @@ class GemmConvGradXPUKernel : public framework::OpKernel { std::vector paddings = context.Attr>("paddings"); std::vector dilations = context.Attr>("dilations"); const int batch_size = static_cast(input->dims()[0]); - PADDLE_ENFORCE_EQ(groups == 1, true, platform::errors::InvalidArgument( - "XPU only support groups == 1.")); - PADDLE_ENFORCE_EQ( - dilations[0] == 1 && dilations[1] == 1, true, - platform::errors::InvalidArgument("XPU only support dilation == 1.")); const int img_c = static_cast(input->dims()[1]); const int img_h = static_cast(input->dims()[2]); const int img_w = static_cast(input->dims()[3]); @@ -133,52 +87,24 @@ class GemmConvGradXPUKernel : public framework::OpKernel { filter_grad->mutable_data(context.GetPlace()); } auto& dev_ctx = context.template device_context(); - // max_output_grad->Resize({4}); - // max_output_grad->mutable_data(context.GetPlace()); - // PADDLE_ENFORCE_EQ( - // xpu::findmax(dev_ctx.x_context(), output_grad->data(), - // output_grad->numel(), - // max_output_grad->data()) == xpu::Error_t::SUCCESS, - // true, - // platform::errors::External( - // "XPU conv kernel error, can not find max_output_grad, please - // check " - // "whether Baidu Kunlun Card is " - // "properly installed.")); - if (input_grad) { - int r = xpu::conv2d_backward_int16( - dev_ctx.x_context(), batch_size, img_c, img_h, img_w, f, win_h, win_w, - strides[0], strides[1], paddings[0], paddings[1], dilations[0], - dilations[1], groups, output_grad->data(), - filter.data(), input_grad->data(), nullptr, nullptr); - // max_output_grad->data(), max_filter->data()); - PADDLE_ENFORCE_EQ( - r, XPU_SUCCESS, - platform::errors::External("XPU conv kernel return wrong value[%d], " - "please check whether Baidu Kunlun Card " - "is properly installed.", - r)); - } - if (filter_grad) { - int r = xpu::conv2d_backward_weight_int16( - dev_ctx.x_context(), batch_size, img_c, img_h, img_w, f, win_h, win_w, - strides[0], strides[1], paddings[0], paddings[1], dilations[0], - dilations[1], groups, output_grad->data(), - input->data(), filter_grad->data(), nullptr, nullptr); - // max_output_grad->data(), max_input->data()); - PADDLE_ENFORCE_EQ( - r, XPU_SUCCESS, - platform::errors::External("XPU conv kernel return wrong value[%d], " - "please check whether Baidu Kunlun Card " - "is properly installed.", - r)); - } + std::vector k_size; + k_size.push_back(win_h); + k_size.push_back(win_w); + int r = xpu::conv2d_grad( + dev_ctx.x_context(), input->data(), filter.data(), + output_grad->data(), input_grad ? input_grad->data() : nullptr, + filter_grad ? filter_grad->data() : nullptr, batch_size, img_c, + img_h, img_w, f, k_size, strides, paddings, dilations, groups, nullptr, + nullptr, nullptr, nullptr, nullptr, true); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External("XPU conv kernel return wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); } }; } // namespace operators } // namespace paddle namespace ops = paddle::operators; -// TODO(xingzhaolong): neon kernel for mobile REGISTER_OP_XPU_KERNEL( depthwise_conv2d, ops::GemmConvXPUKernel); @@ -187,4 +113,7 @@ REGISTER_OP_XPU_KERNEL( REGISTER_OP_XPU_KERNEL( conv2d_grad, ops::GemmConvGradXPUKernel); +REGISTER_OP_XPU_KERNEL( + depthwise_conv2d_grad, + ops::GemmConvGradXPUKernel); #endif diff --git a/paddle/fluid/operators/deformable_conv_op_xpu.cc b/paddle/fluid/operators/deformable_conv_op_xpu.cc new file mode 100644 index 00000000000..18bab83b0ed --- /dev/null +++ b/paddle/fluid/operators/deformable_conv_op_xpu.cc @@ -0,0 +1,286 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_XPU +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/xpu_header.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class DeformableConvXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto* offset = ctx.Input("Offset"); + auto* mask = ctx.Input("Mask"); + Tensor filter = *ctx.Input("Filter"); + Tensor* output = ctx.Output("Output"); + output->mutable_data(ctx.GetPlace()); + + auto& dev_ctx = ctx.template device_context(); + + const int groups = ctx.Attr("groups"); + const int deformable_groups = ctx.Attr("deformable_groups"); + const int im2col_step = ctx.Attr("im2col_step"); + const std::vector strides = ctx.Attr>("strides"); + const std::vector paddings = ctx.Attr>("paddings"); + const std::vector dilations = ctx.Attr>("dilations"); + + PADDLE_ENFORCE_EQ( + deformable_groups == 1, true, + platform::errors::InvalidArgument(( + "XPU only support deformable_groups == 1 in deformable_conv op."))); + PADDLE_ENFORCE_EQ( + groups == 1, true, + platform::errors::InvalidArgument( + ("XPU only support groups == 1 in deformable_conv op."))); + PADDLE_ENFORCE_EQ(filter.dims()[2] <= 8 && filter.dims()[3] <= 8, true, + platform::errors::InvalidArgument( + "Filter high and weight should less than 8 on xpu " + "in deformable_conv op.")); + + const int batch_size = static_cast(input->dims()[0]); + std::vector output_shape_vec(framework::vectorize(output->dims())); + + const T* input_ptr = input->data(); + const T* filter_ptr = filter.data(); + const float* offset_ptr = offset->data(); + const float* mask_ptr = mask->data(); + T* output_prt = output->data(); + + // set zeros for d_table_data + const int zero = 0; + int r = xpu::constant(dev_ctx.x_context(), output_prt, output->numel(), + zero); + PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true, + platform::errors::External( + "XPU API return wrong value[%d], please check where " + "Baidu Kunlun Card is properly installed.", + r)); + int input_dim = input->numel() / input->dims()[0]; + int input_offset_dim = offset->numel() / offset->dims()[0]; + int input_mask_dim = mask->numel() / mask->dims()[0]; + int output_dim = + output_shape_vec[1] * output_shape_vec[2] * output_shape_vec[3]; + std::vector ksize{static_cast(filter.dims()[2]), + static_cast(filter.dims()[3])}; + int n = im2col_step; + int c = input->dims()[1]; + int h = input->dims()[2]; + int w = input->dims()[3]; + int f = filter.dims()[0]; + + for (int i = 0; i < batch_size / im2col_step; ++i) { + int r = xpu::deformable_conv( + dev_ctx.x_context(), input_ptr + i * im2col_step * input_dim, + filter_ptr, offset_ptr + i * im2col_step * input_offset_dim, + mask_ptr + i * im2col_step * input_mask_dim, + output_prt + i * im2col_step * output_dim, n, c, h, w, f, ksize, + strides, paddings, dilations, groups, deformable_groups, nullptr, + nullptr, nullptr, true); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External( + "XPU deformable_conv kernel return wrong value[%d].", r)); + } + } +}; + +template +class DeformableConvGradXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const Tensor* output_grad = + ctx.Input(framework::GradVarName("Output")); + Tensor* input_grad = ctx.Output(framework::GradVarName("Input")); + Tensor* filter_grad = ctx.Output(framework::GradVarName("Filter")); + Tensor* offset_grad = ctx.Output(framework::GradVarName("Offset")); + Tensor* mask_grad = ctx.Output(framework::GradVarName("Mask")); + T* dx_data = nullptr; + T* dw_data = nullptr; + T* dmask_data = nullptr; + T* doffset_data = nullptr; + + if (input_grad != nullptr) { + input_grad->mutable_data(ctx.GetPlace()); + dx_data = input_grad->data(); + } + if (filter_grad != nullptr) { + filter_grad->mutable_data(ctx.GetPlace()); + dw_data = filter_grad->data(); + } + if (offset_grad != nullptr) { + offset_grad->mutable_data(ctx.GetPlace()); + doffset_data = offset_grad->data(); + } + if (mask_grad != nullptr) { + mask_grad->mutable_data(ctx.GetPlace()); + dmask_data = mask_grad->data(); + } + + const Tensor* input = ctx.Input("Input"); + Tensor offset = *ctx.Input("Offset"); + Tensor mask = *ctx.Input("Mask"); + Tensor filter = *ctx.Input("Filter"); + + int groups = ctx.Attr("groups"); + int deformable_groups = ctx.Attr("deformable_groups"); + int im2col_step = ctx.Attr("im2col_step"); + std::vector strides = ctx.Attr>("strides"); + std::vector paddings = ctx.Attr>("paddings"); + std::vector dilations = ctx.Attr>("dilations"); + + PADDLE_ENFORCE_EQ( + deformable_groups == 1, true, + platform::errors::InvalidArgument(( + "XPU only support deformable_groups == 1 in deformable_conv op."))); + PADDLE_ENFORCE_EQ( + groups == 1, true, + platform::errors::InvalidArgument( + ("XPU only support groups == 1 in deformable_conv op."))); + PADDLE_ENFORCE_EQ(filter.dims()[2] <= 8 && filter.dims()[3] <= 8, true, + platform::errors::InvalidArgument( + "Filter high and weight should less than 8 on xpu " + "in deformable_conv op.")); + + auto& dev_ctx = ctx.template device_context(); + const int batch_size = static_cast(input->dims()[0]); + std::vector output_shape_vec( + framework::vectorize(output_grad->dims())); + const T* output_grad_ptr = output_grad->data(); + const T* input_ptr = input->data(); + const T* filter_ptr = filter.data(); + const float* offset_ptr = offset.data(); + const float* mask_ptr = mask.data(); + if (dx_data == nullptr) { + PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast(&dx_data), + input->numel() * sizeof(T)), + XPU_SUCCESS, platform::errors::ResourceExhausted( + "XPU has no enough memory")); + } + if (dw_data == nullptr) { + PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast(&dw_data), + filter.numel() * sizeof(T)), + XPU_SUCCESS, platform::errors::ResourceExhausted( + "XPU has no enough memory")); + } + if (doffset_data == nullptr) { + PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast(&doffset_data), + offset.numel() * sizeof(T)), + XPU_SUCCESS, platform::errors::ResourceExhausted( + "XPU has no enough memory")); + } + if (dmask_data == nullptr) { + PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast(&dmask_data), + mask.numel() * sizeof(T)), + XPU_SUCCESS, platform::errors::ResourceExhausted( + "XPU has no enough memory")); + } + + int input_dim = input->numel() / input->dims()[0]; + int input_offset_dim = offset.numel() / offset.dims()[0]; + int input_mask_dim = mask.numel() / mask.dims()[0]; + int output_dim = + output_shape_vec[1] * output_shape_vec[2] * output_shape_vec[3]; + std::vector ksize{static_cast(filter.dims()[2]), + static_cast(filter.dims()[3])}; + int n = im2col_step; + int c = input->dims()[1]; + int h = input->dims()[2]; + int w = input->dims()[3]; + int f = filter.dims()[0]; + + T* filter_grad_tmp = nullptr; + PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast(&filter_grad_tmp), + filter_grad->numel() * sizeof(T)), + XPU_SUCCESS, platform::errors::ResourceExhausted( + "XPU has no enough memory")); + + // set zeros for d_table_data + const int zero = 0; + int r_dx = + xpu::constant(dev_ctx.x_context(), dx_data, input->numel(), zero); + int r_dw = + xpu::constant(dev_ctx.x_context(), dw_data, filter.numel(), zero); + int r_doffset = xpu::constant(dev_ctx.x_context(), doffset_data, + offset.numel(), zero); + int r_dmask = + xpu::constant(dev_ctx.x_context(), dmask_data, mask.numel(), zero); + int r_filter = xpu::constant(dev_ctx.x_context(), filter_grad_tmp, + filter.numel(), zero); + auto ret = (r_dx == xpu::Error_t::SUCCESS) && (r_dx == r_dw) && + (r_dx == r_doffset) && (r_dx == r_dmask) && (r_dx == r_filter); + PADDLE_ENFORCE_EQ(ret, true, + platform::errors::External( + "XPU API return wrong value, please check where " + "Baidu Kunlun Card is properly installed.")); + + for (int i = 0; i < batch_size / im2col_step; ++i) { + int r = xpu::deformable_conv_grad( + dev_ctx.x_context(), input_ptr + i * im2col_step * input_dim, + filter_ptr, offset_ptr + i * im2col_step * input_offset_dim, + mask_ptr + i * im2col_step * input_mask_dim, + output_grad_ptr + i * im2col_step * output_dim, + dx_data + i * im2col_step * input_dim, filter_grad_tmp, + doffset_data + i * im2col_step * input_offset_dim, + dmask_data + i * im2col_step * input_mask_dim, n, c, h, w, f, ksize, + strides, paddings, dilations, groups, deformable_groups, nullptr, + nullptr, nullptr, nullptr, nullptr, true); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External( + "XPU deformable_conv_grad kernel return wrong value[%d].", r)); + r = baidu::xpu::api::add(dev_ctx.x_context(), filter_grad_tmp, dw_data, + dw_data, filter.numel()); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, + platform::errors::External( + "XPU add kernel return wrong value[%d].", r)); + } + + dev_ctx.Wait(); + xpu_free(filter_grad_tmp); + if (input_grad == nullptr) { + xpu_free(dx_data); + } + if (filter_grad == nullptr) { + xpu_free(dw_data); + } + if (offset_grad == nullptr) { + xpu_free(doffset_data); + } + if (mask_grad == nullptr) { + xpu_free(dmask_data); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +using XPUDeviceContext = paddle::platform::XPUDeviceContext; + +REGISTER_OP_XPU_KERNEL(deformable_conv, + ops::DeformableConvXPUKernel); +REGISTER_OP_XPU_KERNEL( + deformable_conv_grad, + ops::DeformableConvGradXPUKernel); + +#endif diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op_xpu.cc b/paddle/fluid/operators/elementwise/elementwise_div_op_xpu.cc index 4f254a53074..da676a7244f 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op_xpu.cc +++ b/paddle/fluid/operators/elementwise/elementwise_div_op_xpu.cc @@ -28,9 +28,10 @@ class ElementwiseDivXPUKernel : public framework::OpKernel { }; template -class ElementwiseDivGradXPUKernel : public framework::OpKernel { +class ElementwiseDivGradXPUKernel : public ElemwiseGradKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + ElemwiseGradKernel::Compute(ctx); XPUElementwiseGrad(ctx, xpu::div_grad, true); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op_xpu.cc b/paddle/fluid/operators/elementwise/elementwise_max_op_xpu.cc index 411ddb26603..c87db69c57d 100644 --- a/paddle/fluid/operators/elementwise/elementwise_max_op_xpu.cc +++ b/paddle/fluid/operators/elementwise/elementwise_max_op_xpu.cc @@ -29,9 +29,10 @@ class ElementwiseMaxXPUKernel : public framework::OpKernel { }; template -class ElementwiseMaxGradXPUKernel : public framework::OpKernel { +class ElementwiseMaxGradXPUKernel : public ElemwiseGradKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + ElemwiseGradKernel::Compute(ctx); XPUElementwiseGrad(ctx, xpu::max_grad, true); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op_xpu.cc b/paddle/fluid/operators/elementwise/elementwise_min_op_xpu.cc index 0b1e1312264..f1401369ec6 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op_xpu.cc +++ b/paddle/fluid/operators/elementwise/elementwise_min_op_xpu.cc @@ -29,9 +29,10 @@ class ElementwiseMinXPUKernel : public framework::OpKernel { }; template -class ElementwiseMinGradXPUKernel : public framework::OpKernel { +class ElementwiseMinGradXPUKernel : public ElemwiseGradKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + ElemwiseGradKernel::Compute(ctx); XPUElementwiseGrad(ctx, xpu::min_grad, true); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op_xpu.cc b/paddle/fluid/operators/elementwise/elementwise_mul_op_xpu.cc index 02c6900c7c1..23bb04f60a8 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op_xpu.cc +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op_xpu.cc @@ -27,9 +27,10 @@ class ElementwiseMulXPUKernel : public framework::OpKernel { }; // DEFINE_XPU_GRAD_KERNEL(Mul, mul, true); template -class ElementwiseMulGradXPUKernel : public framework::OpKernel { +class ElementwiseMulGradXPUKernel : public ElemwiseGradKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + ElemwiseGradKernel::Compute(ctx); XPUElementwiseGrad(ctx, xpu::mul_grad, true); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_xpu.h b/paddle/fluid/operators/elementwise/elementwise_xpu.h index fdf5aeeba53..89d8487fdbb 100644 --- a/paddle/fluid/operators/elementwise/elementwise_xpu.h +++ b/paddle/fluid/operators/elementwise/elementwise_xpu.h @@ -65,7 +65,7 @@ static std::pair, std::vector> XPUReducesAxisVector( } int yidx = 0; for (size_t i = 0; i < x_vector.size(); ++i) { - if (y[yidx] == 1) { + if (yidx >= y.size() || y[yidx] == 1) { axis_v.push_back(i); yidx++; continue; @@ -134,10 +134,10 @@ void XPUElementwise( std::pair, std::vector> bcast_v = XPUDimsToBroadcastVector(framework::make_ddim(x_dims_array), out_dim); - ret = xpu::broadcast( - dev_ctx.x_context(), x_data, - x_broadcast_tensor.mutable_data(ctx.GetPlace(), z->numel()), - bcast_v.first, bcast_v.second); + ret = xpu::broadcast(dev_ctx.x_context(), x_data, + x_broadcast_tensor.mutable_data( + ctx.GetPlace(), z->numel() * sizeof(T)), + bcast_v.first, bcast_v.second); PADDLE_ENFORCE_EQ( ret, xpu::SUCCESS, platform::errors::External( @@ -153,10 +153,10 @@ void XPUElementwise( std::vector bcast_y_v; std::pair, std::vector> bcast_v = XPUDimsToBroadcastVector(framework::make_ddim(y_dims_array), out_dim); - ret = xpu::broadcast( - dev_ctx.x_context(), y_data, - y_broadcast_tensor.mutable_data(ctx.GetPlace(), z->numel()), - bcast_v.first, bcast_v.second); + ret = xpu::broadcast(dev_ctx.x_context(), y_data, + y_broadcast_tensor.mutable_data( + ctx.GetPlace(), z->numel() * sizeof(T)), + bcast_v.first, bcast_v.second); PADDLE_ENFORCE_EQ( ret, xpu::SUCCESS, platform::errors::External( @@ -231,13 +231,15 @@ void XPUElementwiseGrad(const framework::ExecutionContext& ctx, bool dx_need_reduce = (dx != nullptr) && (dx->numel() != len); bool dy_need_reduce = (dy != nullptr) && (dy->numel() != len); - T* dx_data = ((dx == nullptr) || dx_need_reduce) - ? (dx_local_tensor.mutable_data(ctx.GetPlace(), len)) - : (dx->mutable_data(ctx.GetPlace())); + T* dx_data = + ((dx == nullptr) || dx_need_reduce) + ? (dx_local_tensor.mutable_data(ctx.GetPlace(), len * sizeof(T))) + : (dx->mutable_data(ctx.GetPlace())); - T* dy_data = ((dy == nullptr) || dy_need_reduce) - ? (dy_local_tensor.mutable_data(ctx.GetPlace(), len)) - : (dy->mutable_data(ctx.GetPlace())); + T* dy_data = + ((dy == nullptr) || dy_need_reduce) + ? (dy_local_tensor.mutable_data(ctx.GetPlace(), len * sizeof(T))) + : (dy->mutable_data(ctx.GetPlace())); int ret = xpu::SUCCESS; auto& dev_ctx = @@ -250,8 +252,8 @@ void XPUElementwiseGrad(const framework::ExecutionContext& ctx, XPUDimsToBroadcastVector(framework::make_ddim(x_dims_array), out_dim); ret = xpu::broadcast( dev_ctx.x_context(), x_data, - x_broadcast_tensor.mutable_data(ctx.GetPlace(), len), bcast_v.first, - bcast_v.second); + x_broadcast_tensor.mutable_data(ctx.GetPlace(), len * sizeof(T)), + bcast_v.first, bcast_v.second); PADDLE_ENFORCE_EQ(ret, xpu::SUCCESS, platform::errors::External( "XPU kernel broadcast error occur! %d", ret)); @@ -267,8 +269,8 @@ void XPUElementwiseGrad(const framework::ExecutionContext& ctx, XPUDimsToBroadcastVector(framework::make_ddim(y_dims_array), out_dim); ret = xpu::broadcast( dev_ctx.x_context(), y_data, - y_broadcast_tensor.mutable_data(ctx.GetPlace(), len), bcast_v.first, - bcast_v.second); + y_broadcast_tensor.mutable_data(ctx.GetPlace(), len * sizeof(T)), + bcast_v.first, bcast_v.second); PADDLE_ENFORCE_EQ(ret, xpu::SUCCESS, platform::errors::External( "XPU kernel broadcast error occur! %d", ret)); @@ -287,9 +289,9 @@ void XPUElementwiseGrad(const framework::ExecutionContext& ctx, const framework::DDim& dx_dims = dx->dims(); std::pair, std::vector> reduce_v = XPUReducesAxisVector(out_dim, dx_dims); - ret = xpu::reduce_sum(dev_ctx.x_context(), dx_data, - dx->mutable_data(ctx.GetPlace()), reduce_v.first, - reduce_v.second); + ret = xpu::reduce_sum(dev_ctx.x_context(), dx_data, + dx->mutable_data(ctx.GetPlace()), + reduce_v.first, reduce_v.second); PADDLE_ENFORCE_EQ( ret, xpu::SUCCESS, platform::errors::External("XPU kernel reduce_sum occur error in " @@ -302,9 +304,9 @@ void XPUElementwiseGrad(const framework::ExecutionContext& ctx, const framework::DDim& dy_dims = dy->dims(); std::pair, std::vector> reduce_v = XPUReducesAxisVector(out_dim, dy_dims); - ret = xpu::reduce_sum(dev_ctx.x_context(), dy_data, - dy->mutable_data(ctx.GetPlace()), reduce_v.first, - reduce_v.second); + ret = xpu::reduce_sum(dev_ctx.x_context(), dy_data, + dy->mutable_data(ctx.GetPlace()), + reduce_v.first, reduce_v.second); PADDLE_ENFORCE_EQ( ret, xpu::SUCCESS, platform::errors::External("XPU kernel reduce_sum occur error in " diff --git a/paddle/fluid/operators/expand_op.h b/paddle/fluid/operators/expand_op.h index b3ff4ee1983..8b79a1feb8c 100644 --- a/paddle/fluid/operators/expand_op.h +++ b/paddle/fluid/operators/expand_op.h @@ -56,6 +56,12 @@ inline std::vector get_expand_times( TensorCopySync(*expand_tensor, platform::CPUPlace(), &cpu_expand_tensor); expand_data = cpu_expand_tensor.data(); } +#ifdef PADDLE_WITH_XPU + if (platform::is_xpu_place(expand_tensor->place())) { + TensorCopySync(*expand_tensor, platform::CPUPlace(), &cpu_expand_tensor); + expand_data = cpu_expand_tensor.data(); + } +#endif auto vec_epxand_times = std::vector(expand_data, expand_data + expand_tensor->numel()); return vec_epxand_times; @@ -72,7 +78,15 @@ inline std::vector get_expand_times( framework::Tensor temp; TensorCopySync(*tensor, platform::CPUPlace(), &temp); vec_epxand_times.push_back(*temp.data()); - } else { + } +#ifdef PADDLE_WITH_XPU + else if (platform::is_xpu_place(tensor->place())) { // NOLINT + framework::Tensor temp; + TensorCopySync(*tensor, platform::CPUPlace(), &temp); + vec_epxand_times.push_back(*temp.data()); + } +#endif + else { // NOLINT vec_epxand_times.push_back(*tensor->data()); } } diff --git a/paddle/fluid/operators/interpolate_op_xpu.cc b/paddle/fluid/operators/interpolate_op_xpu.cc new file mode 100644 index 00000000000..6dc42525469 --- /dev/null +++ b/paddle/fluid/operators/interpolate_op_xpu.cc @@ -0,0 +1,258 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include +#include + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/interpolate_op.h" + +#ifdef PADDLE_WITH_XPU + +namespace paddle { +namespace operators { + +using framework::Tensor; +using DataLayout = framework::DataLayout; + +inline std::vector get_new_shape_xpu( + const std::vector& list_new_shape_tensor) { + // get tensor from + std::vector vec_new_shape; + for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) { + auto tensor = list_new_shape_tensor[i]; + PADDLE_ENFORCE_EQ( + tensor->dims(), framework::make_ddim({1}), + platform::errors::InvalidArgument("shape of dim tensor should be [1]")); + if (platform::is_xpu_place(tensor->place())) { + framework::Tensor temp; + TensorCopySync(*tensor, platform::CPUPlace(), &temp); + vec_new_shape.push_back(static_cast(*temp.data())); + } else { + vec_new_shape.push_back(static_cast(*tensor->data())); + } + } + + return vec_new_shape; +} + +template +inline std::vector get_new_data_from_tensor_xpu( + const Tensor* new_data_tensor) { + std::vector vec_new_data; + auto* new_data = new_data_tensor->data(); + framework::Tensor cpu_starts_tensor; + if (platform::is_xpu_place(new_data_tensor->place())) { + TensorCopySync(*new_data_tensor, platform::CPUPlace(), &cpu_starts_tensor); + new_data = cpu_starts_tensor.data(); + } + vec_new_data = std::vector(new_data, new_data + new_data_tensor->numel()); + return vec_new_data; +} + +template +class InterpolateXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("X"); + auto* output = ctx.Output("Out"); + + auto input_dims = input->dims(); + PADDLE_ENFORCE_EQ( + input_dims.size(), 4, + platform::errors::External("XPU Interpolate kernel only support 2d")); + + const std::string data_layout_str = ctx.Attr("data_layout"); + const DataLayout data_layout = + framework::StringToDataLayout(data_layout_str); + int n, c, in_d, in_h, in_w; + ExtractNCDWH(input_dims, data_layout, &n, &c, &in_d, &in_h, &in_w); + + auto interp_method = ctx.Attr("interp_method"); + bool align_corners = ctx.Attr("align_corners"); + int align_mode = ctx.Attr("align_mode"); + + int out_h = ctx.Attr("out_h"); + int out_w = ctx.Attr("out_w"); + + auto list_new_size_tensor = ctx.MultiInput("SizeTensor"); + if (list_new_size_tensor.size() > 0) { + // have size tensor + auto new_size = get_new_shape_xpu(list_new_size_tensor); + out_h = new_size[0]; + out_w = new_size[1]; + } else { + float scale; + auto scale_tensor = ctx.Input("Scale"); + if (scale_tensor != nullptr) { + auto scale_data = get_new_data_from_tensor_xpu(scale_tensor); + scale = scale_data[0]; + } else { + scale = ctx.Attr("scale"); + } + if (scale > 0) { + out_h = static_cast(in_h * scale); + out_w = static_cast(in_w * scale); + } + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { + auto out_size_data = get_new_data_from_tensor_xpu(out_size); + out_h = out_size_data[0]; + out_w = out_size_data[1]; + } + } + PADDLE_ENFORCE_GT(out_h, 0, platform::errors::InvalidArgument( + "out_h in Attr(out_shape) of " + "Op(interpolate) " + "should be greater than 0.")); + PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument( + "out_w in Attr(out_shape) of " + "Op(interpolate) " + "should be greater than 0.")); + framework::DDim dim_out; + if (data_layout == DataLayout::kNCHW) { + dim_out = {n, c, out_h, out_w}; + } else { + dim_out = {n, out_h, out_w, c}; + } + output->mutable_data(dim_out, ctx.GetPlace()); + + if (in_h == out_h && in_w == out_w) { + framework::TensorCopy(*input, ctx.GetPlace(), output); + return; + } + bool nearest = "nearest" == interp_method; + int trans_mode = (align_corners) ? (0) : ((align_mode == 0) ? (1) : (2)); + auto& dev_ctx = ctx.template device_context(); + if (nearest) { + PADDLE_ENFORCE_EQ((data_layout == DataLayout::kNCHW), true, + platform::errors::InvalidArgument( + "XPU nearest is only support NCHW")); + } + int r = xpu::interpolate2d(dev_ctx.x_context(), input->data(), + output->data(), n, c, in_h, in_w, + out_h, out_w, nearest, trans_mode, + (data_layout == DataLayout::kNCHW)); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, + platform::errors::External("XPU interpolate2d kernel " + "return wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + } +}; + +template +class InterpolateGradXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input_grad = ctx.Output(framework::GradVarName("X")); + auto* output_grad = ctx.Input(framework::GradVarName("Out")); + + auto output_grad_dims = output_grad->dims(); + + PADDLE_ENFORCE_EQ(output_grad_dims.size(), 4, + platform::errors::External( + "XPU Interpolategrad kernel only support 2d")); + + auto* input = ctx.Input("X"); + const std::string data_layout_str = ctx.Attr("data_layout"); + const DataLayout data_layout = + framework::StringToDataLayout(data_layout_str); + int n, c, in_d, in_h, in_w; + ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w); + + auto interp_method = ctx.Attr("interp_method"); + bool align_corners = ctx.Attr("align_corners"); + int align_mode = ctx.Attr("align_mode"); + + int out_h = ctx.Attr("out_h"); + int out_w = ctx.Attr("out_w"); + float scale; + auto scale_tensor = ctx.Input("Scale"); + if (scale_tensor != nullptr) { + auto scale_data = get_new_data_from_tensor_xpu(scale_tensor); + scale = scale_data[0]; + } else { + scale = ctx.Attr("scale"); + } + if (scale > 0) { + out_h = static_cast(in_h * scale); + out_w = static_cast(in_w * scale); + } + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { + auto out_size_data = get_new_data_from_tensor_xpu(out_size); + out_h = out_size_data[0]; + out_w = out_size_data[1]; + } + auto list_new_size_tensor = ctx.MultiInput("SizeTensor"); + if (list_new_size_tensor.size() > 0) { + // have size tensor + auto new_size = get_new_shape_xpu(list_new_size_tensor); + out_h = new_size[0]; + out_w = new_size[1]; + } + + framework::DDim dim_grad; + if (data_layout == DataLayout::kNCHW) { + dim_grad = {n, c, in_h, in_w}; + } else { + dim_grad = {n, in_h, in_w, c}; + } + input_grad->mutable_data(dim_grad, ctx.GetPlace()); + + auto& dev_ctx = ctx.template device_context(); + + int r = XPU_SUCCESS; + r = xpu::constant(dev_ctx.x_context(), input_grad->data(), + input_grad->numel(), static_cast(0.0)); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, + platform::errors::External( + "XPU constant in interpolate2d_grad kernel return " + "wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + + if (in_h == out_h && in_w == out_w) { + framework::TensorCopy(*output_grad, ctx.GetPlace(), input_grad); + return; + } + + bool nearest = "nearest" == interp_method; + int trans_mode = (align_corners) ? (0) : ((align_mode == 0) ? (1) : (2)); + + if (nearest) { + PADDLE_ENFORCE_EQ((data_layout == DataLayout::kNCHW), true, + platform::errors::InvalidArgument( + "XPU nearest is only support NCHW")); + } + + r = xpu::interpolate2d_grad(dev_ctx.x_context(), output_grad->data(), + input_grad->data(), n, c, in_h, in_w, + out_h, out_w, nearest, trans_mode, + (data_layout == DataLayout::kNCHW)); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External("XPU interpolate2d_grad kernel return " + "wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_XPU_KERNEL(bilinear_interp, ops::InterpolateXPUKernel); + +REGISTER_OP_XPU_KERNEL(bilinear_interp_grad, + ops::InterpolateGradXPUKernel); +#endif diff --git a/paddle/fluid/operators/pool_op_xpu.cc b/paddle/fluid/operators/pool_op_xpu.cc index 325b7359389..096a81db9bd 100644 --- a/paddle/fluid/operators/pool_op_xpu.cc +++ b/paddle/fluid/operators/pool_op_xpu.cc @@ -43,16 +43,18 @@ class PoolXPUKernel : public framework::OpKernel { bool exclusive = context.Attr("exclusive"); bool is_test = context.Attr("is_test"); bool adaptive = context.Attr("adaptive"); - PADDLE_ENFORCE_EQ( - !adaptive, true, - platform::errors::InvalidArgument( - "The Pool2d XPU OP does not support adaptive == true!")); PADDLE_ENFORCE_EQ( ksize.size(), 2, platform::errors::InvalidArgument( "The Pool2d XPU OP only support 2 dimension pooling!")); + PADDLE_ENFORCE_EQ(!adaptive || (ksize[0] * ksize[1] == 1), true, + platform::errors::InvalidArgument( + "The Pool2d XPU OP does not support (adaptive == " + "true && output_size != 1)")); int* index_data = nullptr; - if (context.Attr("global_pooling")) { + bool global_pooling = context.Attr("global_pooling") || + (adaptive && (ksize[0] * ksize[1] == 1)); + if (global_pooling) { for (size_t i = 0; i < ksize.size(); ++i) { paddings[i] = 0; ksize[i] = static_cast(in_x->dims()[i + 2]); @@ -104,16 +106,18 @@ class PoolGradXPUKernel : public framework::OpKernel { bool exclusive = context.Attr("exclusive"); bool adaptive = context.Attr("adaptive"); const int* index_data = nullptr; - PADDLE_ENFORCE_EQ( - !adaptive, true, - platform::errors::InvalidArgument( - "The Pool2d XPU OP does not support adaptive == true!")); PADDLE_ENFORCE_EQ(ksize.size(), 2, platform::errors::InvalidArgument( "The Pool2d XPU OP only support 2 " "dimension pooling!, but received " "%d-dimension pool kernel size", ksize.size())); - if (context.Attr("global_pooling")) { + PADDLE_ENFORCE_EQ(!adaptive || (ksize[0] * ksize[1] == 1), true, + platform::errors::InvalidArgument( + "The Pool2d XPU OP does not support (adaptive == " + "true && output_size != 1)")); + bool global_pooling = context.Attr("global_pooling") || + (adaptive && (ksize[0] * ksize[1] == 1)); + if (global_pooling) { for (size_t i = 0; i < ksize.size(); ++i) { paddings[i] = 0; ksize[i] = static_cast(in_x->dims()[i + 2]); diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op_xpu.cc b/paddle/fluid/operators/reduce_ops/reduce_sum_op_xpu.cc index b751eca9ee0..f67d43194a0 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_sum_op_xpu.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op_xpu.cc @@ -16,6 +16,8 @@ #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h" #include #include +#include "paddle/fluid/platform/xpu_header.h" + namespace paddle { namespace operators { @@ -27,86 +29,120 @@ class ReduceSumXPUKernel : public framework::OpKernel { platform::is_xpu_place(context.GetPlace()), true, platform::errors::Unavailable("This kernel only runs on XPU.")); bool reduce_all = context.Attr("reduce_all"); - auto* input = context.Input("X"); - auto* output = context.Output("Out"); - output->mutable_data(context.GetPlace()); + auto dims = context.Attr>("dim"); + auto* x = context.Input("X"); + auto* y = context.Output("Out"); + y->mutable_data(context.GetPlace()); auto& dev_ctx = context.template device_context(); + + int out_dtype = context.Attr("out_dtype"); + PADDLE_ENFORCE_EQ( + out_dtype == -1, true, + platform::errors::InvalidArgument( + "XPU only support out_dtype == -1 in reduce_sum op.")); + + const auto* x_data = x->data(); + auto* y_data = y->data(); + const auto& input_dim_size = x->dims().size(); + std::vector true_dims; + for (size_t i = 0; i < dims.size(); ++i) { + if (dims[i] < 0) { + true_dims.push_back(dims[i] + input_dim_size); + } else { + true_dims.push_back(dims[i]); + } + } + + std::vector reduce_dims; + std::vector xdims((input_dim_size)); + for (int i = 0; i < input_dim_size; ++i) { + xdims[i] = x->dims()[i]; + } if (reduce_all) { - int input_len = input->numel(); - int r = xpu::sum(dev_ctx.x_context(), input->data(), output->data(), - input_len); - PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true, - platform::errors::External("XPU kernel error!")); + for (int i = 0; i < input_dim_size; ++i) { + reduce_dims.push_back(i); + } } else { - int ndim = input->dims().size(); - std::vector idims; - for (int i = 0; i < input->dims().size(); i++) { - idims.push_back(input->dims()[i]); + std::set dims_set(true_dims.begin(), true_dims.end()); + for (auto i = 0; i < input_dim_size; i++) { + if (dims_set.find(i) != dims_set.end()) { + if (x->dims()[i] != 1) { + reduce_dims.push_back(i); + } + } } - auto dims = context.Attr>("dim"); - int rdim = dims.size(); - int r = - xpu::reduce(dev_ctx.x_context(), input->data(), output->data(), - idims.data(), ndim, dims.data(), rdim, xpu::REDUCE_SUM); - PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true, - platform::errors::External("XPU kernel error!")); + } + + if (reduce_dims.size() == 0) { + int r = xpu::copy(dev_ctx.x_context(), x_data, y_data, + x->numel() * sizeof(T)); + PADDLE_ENFORCE_EQ( + r == xpu::Error_t::SUCCESS, true, + platform::errors::External("XPU copy in reduce_sum op return " + "wrong value[%d %s].", + r, XPUAPIErrorMsg[r])); + } else { + int r = xpu::reduce_sum(dev_ctx.x_context(), x_data, y_data, xdims, + reduce_dims); + PADDLE_ENFORCE_EQ( + r == xpu::Error_t::SUCCESS, true, + platform::errors::External("XPU reduce_sum in reduce_sum op return" + " wrong value[%d %s].", + r, XPUAPIErrorMsg[r])); } } }; + template class ReduceSumGradXPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto dims = context.Attr>("dim"); bool reduce_all = context.Attr("reduce_all"); - auto* input0 = context.Input("X"); - auto* input2 = context.Input(framework::GradVarName("Out")); - auto* output = context.Output(framework::GradVarName("X")); - output->mutable_data(context.GetPlace()); - const auto* input2_d = input2->data(); - auto* output_d = output->data(); + auto* x = context.Input("X"); + auto* out = context.Input(framework::GradVarName("Out")); + auto* x_grad = context.Output(framework::GradVarName("X")); + + int in_dtype = context.Attr("in_dtype"); + PADDLE_ENFORCE_EQ( + in_dtype == -1, true, + platform::errors::InvalidArgument( + "XPU only support in_dtype == -1 in reduce_sum_grad op.")); + auto& dev_ctx = context.template device_context(); - int r = 0; - std::vector idims; - int reduce_dim = 0; - if (reduce_all) { - idims.push_back(input0->numel()); - idims.push_back(1); - idims.push_back(1); - r = xpu::reduce_grad(dev_ctx.x_context(), input2_d, output_d, - idims.data(), idims.size(), &reduce_dim, 1, - xpu::REDUCE_SUM); - PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true, - platform::errors::External("XPU kernel error!")); - } else if (dims.size() == 1) { - // handle reduce by one dimension - int reduce_dim_index = dims[0]; - if (reduce_dim_index < 0) { - reduce_dim_index += input0->dims().size(); - } - auto& input_dim = input0->dims(); - int before_dim = 1; - for (int i = 0; i < reduce_dim_index; ++i) { - before_dim *= input_dim[i]; + x_grad->mutable_data(context.GetPlace()); + const auto* out_data = out->data(); + auto* x_grad_data = x_grad->data(); + + const auto& input_dim_size = x->dims().size(); + std::vector true_dims; + for (size_t i = 0; i < dims.size(); ++i) { + if (dims[i] < 0) { + true_dims.push_back(dims[i] + input_dim_size); + } else { + true_dims.push_back(dims[i]); } - int reduce_dim = input_dim[reduce_dim_index]; - int after_dim = 1; - for (int i = reduce_dim_index + 1; i < input_dim.size(); ++i) { - after_dim *= input_dim[i]; + } + + std::vector ydims(input_dim_size); + std::vector xdims((input_dim_size)); + std::set dims_set(true_dims.begin(), true_dims.end()); + for (auto i = 0; i < input_dim_size; i++) { + xdims[i] = x->dims()[i]; + if (dims_set.find(i) != dims_set.end() || reduce_all) { + ydims[i] = 1; + } else { + ydims[i] = x->dims()[i]; } - idims.push_back(before_dim); - idims.push_back(input_dim[reduce_dim_index]); - idims.push_back(after_dim); - reduce_dim = 1; - r = xpu::reduce_grad(dev_ctx.x_context(), input2_d, output_d, - idims.data(), idims.size(), &reduce_dim, 1, - xpu::REDUCE_SUM); - PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true, - platform::errors::External("XPU kernel error!")); - } else { - PADDLE_THROW( - platform::errors::Unimplemented("unsupport reduce sum grad")); } + + int r = xpu::broadcast(dev_ctx.x_context(), out_data, x_grad_data, ydims, + xdims); + PADDLE_ENFORCE_EQ( + r == xpu::Error_t::SUCCESS, true, + platform::errors::External("XPU broadcast in reduce_sum_grad op return" + " wrong value[%d %s].", + r, XPUAPIErrorMsg[r])); } }; diff --git a/paddle/fluid/operators/softmax_op_xpu.cc b/paddle/fluid/operators/softmax_op_xpu.cc index 29740000aeb..312c5d2dde1 100644 --- a/paddle/fluid/operators/softmax_op_xpu.cc +++ b/paddle/fluid/operators/softmax_op_xpu.cc @@ -1,11 +1,8 @@ /* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -30,29 +27,27 @@ class SoftmaxXPUKernel : public framework::OpKernel { auto* x = context.Input("X"); auto* out = context.Output("Out"); const int rank = x->dims().size(); - const int axis = CanonicalAxis(context.Attr("axis"), rank); - PADDLE_ENFORCE_EQ(axis == -1 || axis == rank - 1, true, - platform::errors::InvalidArgument( - "xpu softmax kernel only support last dimension of x " - "(axis==-1 or axis==x_dims-1), but received axis: " - "%d, x's shape: %s.", - axis, x->dims())); + int axis = CanonicalAxis(context.Attr("axis"), rank); // allocate memory on device. out->mutable_data(context.GetPlace()); - const int n = SizeToAxis(axis, x->dims()); - const int d = SizeFromAxis(axis, x->dims()); + std::vector x_dims; + for (int i = 0; i < rank; i++) { + x_dims.push_back(x->dims()[i]); + } + if (axis < 0) { + axis += rank; + } auto& dev_ctx = context.template device_context(); - int r = xpu::softmax2d_forward(dev_ctx.x_context(), x->data(), - out->data(), n, d, d <= 2048); + int r = xpu::softmax(dev_ctx.x_context(), x->data(), + out->data(), x_dims, axis); PADDLE_ENFORCE_EQ( r, XPU_SUCCESS, platform::errors::External("XPU API(softmax2d_forward) return wrong " - "value[%d], please check whether " - "Baidu Kunlun Card is properly installed.", - r)); + "value[%d %s]", + r, XPUAPIErrorMsg[r])); } }; @@ -64,24 +59,28 @@ class SoftmaxGradXPUKernel : public framework::OpKernel { auto* dout = context.Input(framework::GradVarName("Out")); auto* dx = context.Output(framework::GradVarName("X")); const int rank = dx->dims().size(); - const int axis = CanonicalAxis(context.Attr("axis"), rank); + int axis = CanonicalAxis(context.Attr("axis"), rank); // allocate memory on device. dx->mutable_data(context.GetPlace()); - const int n = SizeToAxis(axis, dx->dims()); - const int d = SizeFromAxis(axis, dx->dims()); + std::vector x_dims; + for (int i = 0; i < rank; i++) { + x_dims.push_back(dx->dims()[i]); + } + if (axis < 0) { + axis += rank; + } auto& dev_ctx = context.template device_context(); - int r = - xpu::softmax2d_backward(dev_ctx.x_context(), out->data(), - dout->data(), dx->data(), n, d); + int r = xpu::softmax_grad(dev_ctx.x_context(), out->data(), + dout->data(), dx->data(), x_dims, + axis); PADDLE_ENFORCE_EQ( r, XPU_SUCCESS, platform::errors::External("XPU API(softmax2d_backward) return wrong " - "value[%d], please check whether " - "Baidu Kunlun Card is properly installed.", - r)); + "value[%d %s]", + r, XPUAPIErrorMsg[r])); } }; diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc index 368a12057c8..346ed965d06 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc @@ -70,7 +70,8 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel { r)); } else { Tensor labels_int32; - labels_int32.mutable_data(context.GetPlace(), labels->numel()); + labels_int32.mutable_data(context.GetPlace(), + labels->numel() * sizeof(int32_t)); r = xpu::cast_v2( dev_ctx.x_context(), labels->data(), labels_int32.data(), labels->numel()); diff --git a/paddle/fluid/operators/transpose_op_xpu.cc b/paddle/fluid/operators/transpose_op_xpu.cc index c7ecf2ebfaa..2748c07f9e6 100644 --- a/paddle/fluid/operators/transpose_op_xpu.cc +++ b/paddle/fluid/operators/transpose_op_xpu.cc @@ -17,105 +17,27 @@ limitations under the License. */ #include #include #include +#include "paddle/fluid/platform/xpu_header.h" namespace paddle { namespace operators { using framework::Tensor; -bool XPUSupported(int ndims, const std::vector& axis) { - /* - * XPU currently support: - * permute = {0, 2, 1}, permute = {1, 0}, - * permute = {0, 2, 1, 3}, permute = {1, 0, 2}, - * permute = {0, 2, 3, 1} - */ - bool is_supported = false; - std::vector permute_10(2, 0); - std::vector permute_102(3, 0); - std::vector permute_021(3, 0); - std::vector permute_210(3, 0); - std::vector permute_0213(4, 0); - std::vector permute_0231(4, 0); - std::vector permute_0312(4, 0); - std::vector permute_3201(4, 0); - permute_10[0] = 1; - permute_102[0] = 1; - permute_102[2] = 2; - permute_021[1] = 2; - permute_021[2] = 1; - permute_210[0] = 2; - permute_210[1] = 1; - permute_0213[1] = 2; - permute_0213[2] = 1; - permute_0213[3] = 3; - permute_0231[1] = 2; - permute_0231[2] = 3; - permute_0231[3] = 1; - permute_0312[1] = 3; - permute_0312[2] = 1; - permute_0312[3] = 2; - permute_3201[0] = 3; - permute_3201[1] = 2; - permute_3201[3] = 1; - switch (ndims) { - case 2: - if (axis == permute_10) { - is_supported = true; - } - break; - case 3: - if ((axis == permute_021) || (axis == permute_102) || - (axis == permute_210)) { - is_supported = true; - } - break; - case 4: - if ((axis == permute_0213) || (axis == permute_0231) || - (axis == permute_0312) || (axis == permute_3201)) { - is_supported = true; - } - break; - default: - PADDLE_THROW(platform::errors::Unimplemented( - "Tensors with rank only 2, 3 and 4 are supported on XPU")); - } - return is_supported; -} - template class TransposeXPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto x = context.Input("X"); auto out = context.Output("Out"); + // axis is permute auto axis = context.Attr>("axis"); int ndims = axis.size(); const auto x_dims = x->dims(); - const T* x_data = x->data(); T* y_data = out->mutable_data(context.GetPlace()); - if (!XPUSupported(ndims, axis)) { - VLOG(0) << "XPU does not support the permute, try to do on cpu"; - framework::Tensor x_cpu; - framework::Tensor out_cpu; - auto x_cpu_data = x_cpu.mutable_data(x->dims(), platform::CPUPlace()); - auto out_cpu_data = - out_cpu.mutable_data(out->dims(), platform::CPUPlace()); - memory::Copy(platform::CPUPlace(), reinterpret_cast(x_cpu_data), - BOOST_GET_CONST(platform::XPUPlace, context.GetPlace()), - (const void*)x_data, x->numel() * sizeof(T)); - - const platform::CPUDeviceContext* cpu_dev_ctx = - static_cast( - platform::DeviceContextPool::Instance().Get( - platform::CPUPlace())); - TransCompute(ndims, *cpu_dev_ctx, x_cpu, - &out_cpu, axis); - memory::Copy(BOOST_GET_CONST(platform::XPUPlace, context.GetPlace()), - reinterpret_cast(y_data), platform::CPUPlace(), - (const void*)out_cpu_data, out->numel() * sizeof(T)); + if (out->numel() == 0) { return; } @@ -123,10 +45,9 @@ class TransposeXPUKernel : public framework::OpKernel { for (int i = 0; i < ndims; ++i) { x_shape_host[i] = x_dims[i]; } - int* permute_host = axis.data(); auto& dev_ctx = context.template device_context(); - int r = xpu::transpose(dev_ctx.x_context(), x_data, y_data, - x_shape_host.data(), permute_host, ndims); + int r = xpu::transpose(dev_ctx.x_context(), x_data, y_data, x_shape_host, + axis); PADDLE_ENFORCE_EQ( r, xpu::Error_t::SUCCESS, platform::errors::External("XPU kernel error! error code=%d", r)); @@ -151,20 +72,13 @@ class TransposeGradXPUKernel : public framework::OpKernel { } int ndims = axis.size(); - if (!XPUSupported(ndims, reversed_axis)) { - PADDLE_THROW( - platform::errors::Unimplemented("XPU does not support the permute")); - } - std::vector out_shape_host(ndims, 0); for (int i = 0; i < ndims; ++i) { out_shape_host[i] = out_grad->dims()[i]; } - int* permute_host = reversed_axis.data(); auto& dev_ctx = context.template device_context(); - int r = xpu::transpose(dev_ctx.x_context(), out_grad->data(), - x_grad->data(), out_shape_host.data(), - permute_host, ndims); + int r = xpu::transpose(dev_ctx.x_context(), out_grad->data(), + x_grad->data(), out_shape_host, reversed_axis); PADDLE_ENFORCE_EQ( r, xpu::Error_t::SUCCESS, platform::errors::External("XPU kernel error! error code=%d", r)); diff --git a/paddle/fluid/operators/uniform_random_op_xpu.cc b/paddle/fluid/operators/uniform_random_op_xpu.cc index 507bd7e9ea9..d8b82ad5f86 100644 --- a/paddle/fluid/operators/uniform_random_op_xpu.cc +++ b/paddle/fluid/operators/uniform_random_op_xpu.cc @@ -29,37 +29,68 @@ class XPUUniformRandomKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext &ctx) const override { framework::Tensor *tensor = nullptr; auto out_var = ctx.OutputVar("Out"); - if (out_var->IsType()) { - tensor = out_var->GetMutable(); - } else if (out_var->IsType()) { - auto shape = ctx.Attr>("shape"); + std::vector new_shape; + auto list_new_shape_tensor = + ctx.MultiInput("ShapeTensorList"); + if (list_new_shape_tensor.size() > 0 || ctx.HasInput("ShapeTensor")) { + if (ctx.HasInput("ShapeTensor")) { + auto *shape_tensor = ctx.Input("ShapeTensor"); + new_shape = GetNewDataFromShapeTensor(shape_tensor); + } else if (list_new_shape_tensor.size() > 0) { + new_shape = GetNewDataFromShapeTensorList(list_new_shape_tensor); + } + } + + if (out_var->IsType()) { auto *selected_rows = out_var->GetMutable(); tensor = selected_rows->mutable_value(); + auto shape = ctx.Attr>("shape"); + if (!new_shape.empty()) shape = new_shape; tensor->Resize(framework::make_ddim(shape)); selected_rows->mutable_rows()->reserve(shape[0]); + } else if (out_var->IsType()) { + tensor = out_var->GetMutable(); + if (!new_shape.empty()) tensor->Resize(framework::make_ddim(new_shape)); } else { PADDLE_THROW(platform::errors::InvalidArgument( - "Expected type of Output(out) in uniform_random_op must be " - "LoDTensor, " - "SelectedRows. But got unsupport type: %s.", + "Expected type of Output(out) in uniform_random_op must be Tensor, " + "SelectedRows. But got " + "unsupport type: %s.", framework::ToTypeName(out_var->Type()))); } T *data = tensor->mutable_data(ctx.GetPlace()); int64_t size = tensor->numel(); + std::unique_ptr data_cpu(new T[size]); std::uniform_real_distribution dist( static_cast(ctx.Attr("min")), static_cast(ctx.Attr("max"))); unsigned int seed = static_cast(ctx.Attr("seed")); - // TODO(pangyoki): implement GetXPURandomEngine to set different seeds on - // corresponding XPU device. auto engine = framework::GetCPURandomEngine(seed); - std::unique_ptr data_cpu(new T[size]); for (int64_t i = 0; i < size; ++i) { data_cpu[i] = dist(*engine); } + unsigned int diag_num = + static_cast(ctx.Attr("diag_num")); + unsigned int diag_step = + static_cast(ctx.Attr("diag_step")); + auto diag_val = static_cast(ctx.Attr("diag_val")); + if (diag_num > 0) { + PADDLE_ENFORCE_GT( + size, (diag_num - 1) * (diag_step + 1), + platform::errors::InvalidArgument( + "ShapeInvalid: the diagonal's elements is equal (num-1) " + "* (step-1) with num %d, step %d," + "It should be smaller than %d, but received %d", + diag_num, diag_step, (diag_num - 1) * (diag_step + 1), size)); + for (int64_t i = 0; i < diag_num; ++i) { + int64_t pos = i * diag_step + i; + data_cpu[pos] = diag_val; + } + } + memory::Copy(BOOST_GET_CONST(platform::XPUPlace, ctx.GetPlace()), data, platform::CPUPlace(), reinterpret_cast(data_cpu.get()), size * sizeof(T)); diff --git a/paddle/fluid/platform/xpu_header.h b/paddle/fluid/platform/xpu_header.h index 66982769837..98bd019ad96 100644 --- a/paddle/fluid/platform/xpu_header.h +++ b/paddle/fluid/platform/xpu_header.h @@ -15,11 +15,13 @@ #pragma once #ifdef PADDLE_WITH_XPU +#include #include #include #include "paddle/fluid/platform/errors.h" #include "xpu/api.h" +#include "xpu/refactor/math.h" #include "xpu/refactor/nn.h" #include "xpu/runtime.h" #include "xpu/runtime_ex.h" @@ -48,4 +50,11 @@ class XPUActHelper { return res->second; } }; + +static std::map XPUAPIErrorMsg = { + {xpu::Error_t::SUCCESS, "xpu api success"}, + {xpu::Error_t::INVALID_PARAM, "xpu api invalid param"}, + {xpu::Error_t::RUNTIME_ERROR, "xpu api runtime error"}, + {xpu::Error_t::NO_ENOUGH_WORKSPACE, "xpu api no enough workspace"}}; + #endif diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 215b4cd039f..fdd236a58f0 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -1915,6 +1915,10 @@ def load(program, model_path, executor=None, var_list=None): place = paddle.fluid.CPUPlace() elif p.is_cuda_pinned_place(): place = paddle.fluid.CUDAPinnedPlace() + elif p.is_xpu_place(): + p = paddle.fluid.core.Place() + p.set_place(t._place()) + place = paddle.fluid.XPUPlace(p.xpu_device_id()) else: p = paddle.fluid.core.Place() p.set_place(t._place()) diff --git a/python/paddle/fluid/tests/unittests/op_test_xpu.py b/python/paddle/fluid/tests/unittests/op_test_xpu.py index 7e19d8e4d8a..37b446174d6 100644 --- a/python/paddle/fluid/tests/unittests/op_test_xpu.py +++ b/python/paddle/fluid/tests/unittests/op_test_xpu.py @@ -362,17 +362,6 @@ class XPUOpTest(OpTest): if not type(output_names) is list: output_names = [output_names] - numeric_grads = user_defined_grads or [ - get_numeric_gradient( - place, - self.scope, - self.op, - self.inputs, - input_to_check, - output_names, - delta=numeric_grad_delta, - in_place=in_place) for input_to_check in inputs_to_check - ] analytic_grads = self._get_gradient(inputs_to_check, place, output_names, no_grad_set) return analytic_grads diff --git a/python/paddle/fluid/tests/unittests/xpu/test_bilinear_interp_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_bilinear_interp_op_xpu.py new file mode 100755 index 00000000000..f8ae945b6eb --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_bilinear_interp_op_xpu.py @@ -0,0 +1,519 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import unittest +import numpy as np +import paddle +import paddle.fluid.core as core +import sys +sys.path.append("..") +from op_test_xpu import XPUOpTest +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard +import time + +paddle.enable_static() + + +def bilinear_interp_np(input, + out_h, + out_w, + out_size=None, + actual_shape=None, + align_corners=True, + align_mode=0, + data_layout='NCHW'): + """bilinear interpolation implement in shape [N, C, H, W]""" + if data_layout == "NHWC": + input = np.transpose(input, (0, 3, 1, 2)) # NHWC => NCHW + if out_size is not None: + out_h = out_size[0] + out_w = out_size[1] + if actual_shape is not None: + out_h = actual_shape[0] + out_w = actual_shape[1] + batch_size, channel, in_h, in_w = input.shape + + ratio_h = ratio_w = 0.0 + if out_h > 1: + if (align_corners): + ratio_h = (in_h - 1.0) / (out_h - 1.0) + else: + ratio_h = 1.0 * in_h / out_h + if out_w > 1: + if (align_corners): + ratio_w = (in_w - 1.0) / (out_w - 1.0) + else: + ratio_w = 1.0 * in_w / out_w + + out = np.zeros((batch_size, channel, out_h, out_w)) + + for i in range(out_h): + if (align_mode == 0 and not align_corners): + h = int(ratio_h * (i + 0.5) - 0.5) + else: + h = int(ratio_h * i) + + h = max(0, h) + hid = 1 if h < in_h - 1 else 0 + if (align_mode == 0 and not align_corners): + idx_src_h = max(ratio_h * (i + 0.5) - 0.5, 0) + h1lambda = idx_src_h - h + else: + h1lambda = ratio_h * i - h + h2lambda = 1.0 - h1lambda + for j in range(out_w): + if (align_mode == 0 and not align_corners): + w = int(ratio_w * (j + 0.5) - 0.5) + else: + w = int(ratio_w * j) + w = max(0, w) + wid = 1 if w < in_w - 1 else 0 + if (align_mode == 0 and not align_corners): + idx_src_w = max(ratio_w * (j + 0.5) - 0.5, 0) + w1lambda = idx_src_w - w + else: + w1lambda = ratio_w * j - w + w2lambda = 1.0 - w1lambda + + out[:, :, i, j] = h2lambda*(w2lambda*input[:, :, h, w] + + w1lambda*input[:, :, h, w+wid]) + \ + h1lambda*(w2lambda*input[:, :, h+hid, w] + + w1lambda*input[:, :, h+hid, w+wid]) + + if data_layout == "NHWC": + out = np.transpose(out, (0, 2, 3, 1)) # NCHW => NHWC + + return out.astype(input.dtype) + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpOp(XPUOpTest): + def setUp(self): + self.use_xpu = True + self.out_size = None + self.actual_shape = None + self.data_layout = 'NCHW' + self.init_test_case() + self.op_type = "bilinear_interp" + input_np = np.random.random(self.input_shape).astype("float32") + + if self.data_layout == "NCHW": + in_h = self.input_shape[2] + in_w = self.input_shape[3] + else: + in_h = self.input_shape[1] + in_w = self.input_shape[2] + + if self.scale > 0: + out_h = int(in_h * self.scale) + out_w = int(in_w * self.scale) + else: + out_h = self.out_h + out_w = self.out_w + + output_np = bilinear_interp_np(input_np, out_h, out_w, self.out_size, + self.actual_shape, self.align_corners, + self.align_mode, self.data_layout) + self.inputs = {'X': input_np} + if self.out_size is not None: + self.inputs['OutSize'] = self.out_size + if self.actual_shape is not None: + self.inputs['OutSize'] = self.actual_shape + + self.attrs = { + 'out_h': self.out_h, + 'out_w': self.out_w, + 'scale': self.scale, + 'interp_method': self.interp_method, + 'align_corners': self.align_corners, + 'align_mode': self.align_mode, + 'data_layout': self.data_layout + } + self.outputs = {'Out': output_np} + + def test_check_output(self): + place = paddle.XPUPlace(0) + self.check_output_with_place(place) + + def test_check_grad(self): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['X'], 'Out', in_place=True) + + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [2, 3, 5, 5] + self.out_h = 2 + self.out_w = 2 + self.scale = 0. + self.out_size = np.array([3, 3]).astype("int32") + self.align_corners = True + self.align_mode = 1 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpCase1(TestBilinearInterpOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [4, 1, 7, 8] + self.out_h = 1 + self.out_w = 1 + self.scale = 0. + self.align_corners = True + self.align_mode = 1 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpCase2(TestBilinearInterpOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [3, 3, 9, 6] + self.out_h = 12 + self.out_w = 12 + self.scale = 0. + self.align_corners = True + self.align_mode = 1 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpCase3(TestBilinearInterpOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [1, 1, 32, 64] + self.out_h = 64 + self.out_w = 32 + self.scale = 0. + self.align_corners = True + self.align_mode = 1 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpCase4(TestBilinearInterpOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [4, 1, 7, 8] + self.out_h = 1 + self.out_w = 1 + self.scale = 0. + self.out_size = np.array([2, 2]).astype("int32") + self.align_corners = True + self.align_mode = 1 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpCase5(TestBilinearInterpOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [3, 3, 9, 6] + self.out_h = 12 + self.out_w = 12 + self.scale = 0. + self.out_size = np.array([11, 11]).astype("int32") + self.align_corners = True + self.align_mode = 1 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpCase6(TestBilinearInterpOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [1, 1, 32, 64] + self.out_h = 64 + self.out_w = 32 + self.scale = 0. + self.out_size = np.array([65, 33]).astype("int32") + self.align_corners = True + self.align_mode = 1 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpSame(TestBilinearInterpOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [2, 3, 32, 64] + self.out_h = 32 + self.out_w = 64 + self.scale = 0. + self.align_corners = True + self.align_mode = 1 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpActualShape(TestBilinearInterpOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [3, 2, 32, 16] + self.out_h = 64 + self.out_w = 32 + self.scale = 0. + self.out_size = np.array([66, 40]).astype("int32") + self.align_corners = True + self.align_mode = 1 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpDataLayout(TestBilinearInterpOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [2, 5, 5, 3] + self.out_h = 2 + self.out_w = 2 + self.scale = 0. + self.out_size = np.array([3, 3]).astype("int32") + self.align_corners = True + self.align_mode = 1 + self.data_layout = "NHWC" + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpOtherMethod1(TestBilinearInterpOp): + def set_align_mode(self): + self.align_corners = False + self.align_mode = 1 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpWithMethod2(TestBilinearInterpOp): + def set_align_mode(self): + self.align_corners = False + self.align_mode = 0 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpWithMethod3(TestBilinearInterpOp): + def set_align_mode(self): + self.align_corners = True + self.align_mode = 0 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpScale1(TestBilinearInterpOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [2, 3, 5, 7] + self.out_h = 60 + self.out_w = 25 + self.scale = 2. + self.align_corners = True + self.align_mode = 1 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpScale2(TestBilinearInterpOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [2, 3, 5, 7] + self.out_h = 60 + self.out_w = 25 + self.scale = 1. + self.align_corners = True + self.align_mode = 1 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpScale3(TestBilinearInterpOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [2, 3, 5, 7] + self.out_h = 60 + self.out_w = 25 + self.scale = 1.5 + self.align_corners = True + self.align_mode = 1 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpZero(TestBilinearInterpOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [2, 3, 5, 7] + self.out_h = 60 + self.out_w = 25 + self.scale = 0.2 + self.align_corners = False + self.align_mode = 0 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpOp_attr_tensor(XPUOpTest): + def setUp(self): + self.out_size = None + self.actual_shape = None + self.init_test_case() + self.op_type = "bilinear_interp" + self.shape_by_1Dtensor = False + self.scale_by_1Dtensor = False + self.attrs = { + 'interp_method': self.interp_method, + 'align_corners': self.align_corners, + } + + input_np = np.random.random(self.input_shape).astype("float32") + self.inputs = {'X': input_np} + + if self.scale_by_1Dtensor: + self.inputs['Scale'] = np.array([self.scale]).astype("float32") + elif self.scale > 0: + out_h = int(self.input_shape[2] * self.scale) + out_w = int(self.input_shape[3] * self.scale) + self.attrs['scale'] = self.scale + else: + out_h = self.out_h + out_w = self.out_w + + if self.shape_by_1Dtensor: + self.inputs['OutSize'] = self.out_size + elif self.out_size is not None: + size_tensor = [] + for index, ele in enumerate(self.out_size): + size_tensor.append(("x" + str(index), np.ones( + (1)).astype('int32') * ele)) + self.inputs['SizeTensor'] = size_tensor + + self.attrs['out_h'] = self.out_h + self.attrs['out_w'] = self.out_w + output_np = bilinear_interp_np(input_np, out_h, out_w, self.out_size, + self.actual_shape, self.align_corners) + self.outputs = {'Out': output_np} + + def test_check_output(self): + place = paddle.XPUPlace(0) + self.check_output_with_place(place) + + def test_check_grad(self): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['X'], 'Out', in_place=True) + + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [2, 3, 5, 5] + self.out_h = 3 + self.out_w = 3 + self.scale = 0. + self.out_size = [3, 3] + self.align_corners = True + + +# out_size is a 1-D tensor +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterp_attr_tensor_Case1(TestBilinearInterpOp_attr_tensor): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [3, 3, 9, 6] + self.out_h = 12 + self.out_w = 12 + self.scale = 0. + self.out_size = [8, 12] + self.align_corners = True + + +# scale is a 1-D tensor +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterp_attr_tensor_Case2(TestBilinearInterpOp_attr_tensor): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [3, 2, 32, 16] + self.out_h = 64 + self.out_w = 32 + self.scale = 0. + self.out_size = np.array([66, 40]).astype("int32") + self.align_corners = True + self.shape_by_1Dtensor = True + + +# scale is a 1-D tensor +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterp_attr_tensor_Case3(TestBilinearInterpOp_attr_tensor): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [3, 2, 32, 16] + self.out_h = 64 + self.out_w = 32 + self.scale = 2.0 + self.out_size = None + self.align_corners = True + self.scale_by_1Dtensor = True + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpOpAPI(unittest.TestCase): + def test_case(self): + x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32") + + dim = fluid.data(name="dim", shape=[1], dtype="int32") + shape_tensor = fluid.data(name="shape_tensor", shape=[2], dtype="int32") + actual_size = fluid.data(name="actual_size", shape=[2], dtype="int32") + scale_tensor = fluid.data( + name="scale_tensor", shape=[1], dtype="float32") + + out1 = fluid.layers.resize_bilinear(x, out_shape=[12, 12]) + out2 = fluid.layers.resize_bilinear(x, out_shape=[12, dim]) + out3 = fluid.layers.resize_bilinear(x, out_shape=shape_tensor) + out4 = fluid.layers.resize_bilinear( + x, out_shape=[4, 4], actual_shape=actual_size) + out5 = fluid.layers.resize_bilinear(x, scale=scale_tensor) + + x_data = np.random.random((2, 3, 6, 6)).astype("float32") + dim_data = np.array([12]).astype("int32") + shape_data = np.array([12, 12]).astype("int32") + actual_size_data = np.array([12, 12]).astype("int32") + scale_data = np.array([2.0]).astype("float32") + + place = core.XPUPlace(0) + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + results = exe.run(fluid.default_main_program(), + feed={ + "x": x_data, + "dim": dim_data, + "shape_tensor": shape_data, + "actual_size": actual_size_data, + "scale_tensor": scale_data + }, + fetch_list=[out1, out2, out3, out4, out5], + return_numpy=True) + + expect_res = bilinear_interp_np( + x_data, out_h=12, out_w=12, align_corners=True) + for res in results: + self.assertTrue(np.allclose(res, expect_res)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_concat_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_concat_op_xpu.py index bb5d7134a1b..b1a5e422acf 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_concat_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_concat_op_xpu.py @@ -19,16 +19,20 @@ import sys sys.path.append("..") import unittest import numpy as np + from op_test import OpTest, skip_check_grad_ci +from op_test_xpu import XPUOpTest import paddle.fluid as fluid from paddle.fluid import compiler, Program, program_guard, core import paddle -class TestConcatOp(OpTest): +class TestConcatOp(XPUOpTest): def setUp(self): self.op_type = "concat" self.dtype = self.get_dtype() + self.use_xpu = True + self.use_mkldnn = False self.init_test_data() self.inputs = {'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)]} self.attrs = {'axis': self.axis} @@ -44,7 +48,7 @@ class TestConcatOp(OpTest): } def get_dtype(self): - return "float64" + return "float32" def test_check_output(self): if paddle.is_compiled_with_xpu(): @@ -131,7 +135,7 @@ class TestConcatOp6(TestConcatOp): def test_check_output(self): if paddle.is_compiled_with_xpu(): place = paddle.XPUPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_dygraph=False) def test_check_grad(self): if paddle.is_compiled_with_xpu(): @@ -147,94 +151,6 @@ class TestConcatOp6(TestConcatOp): self.axis = 0 -class TestConcatOpError(unittest.TestCase): - def test_errors(self): - with program_guard(Program(), Program()): - # The input type of concat_op should be list. - x1 = fluid.layers.data(shape=[4], dtype='int32', name='x1') - fluid.layers.concat(x1) - # The item in input must be Variable. - x2 = fluid.create_lod_tensor( - np.array([[-1]]), [[1]], fluid.CPUPlace()) - x3 = fluid.create_lod_tensor( - np.array([[-1]]), [[1]], fluid.CPUPlace()) - self.assertRaises(TypeError, fluid.layers.concat, [x2]) - # The input dtype of concat_op must be float16, float32, float64, int32, int64. - x4 = fluid.layers.data(shape=[4], dtype='uint8', name='x4') - x5 = fluid.layers.data(shape=[4], dtype='uint8', name='x5') - self.assertRaises(TypeError, fluid.layers.concat, [x4, x5]) - x6 = fluid.layers.data(shape=[4], dtype='float16', name='x6') - x7 = fluid.layers.data(shape=[4], dtype='float16', name='x7') - x8 = fluid.layers.data(shape=[4], dtype='float32', name='x8') - fluid.layers.concat([x6, x7]) - - # The type of axis in concat_op should be int or Variable. - def test_axis_type(): - fluid.layers.concat([x6, x7], 3.2) - - self.assertRaises(TypeError, test_axis_type) - - def test_input_same_dtype(): - fluid.layers.concat([x7, x8]) - - self.assertRaises(TypeError, test_input_same_dtype) - - -class TestConcatAPI(unittest.TestCase): - def test_fluid_api(self): - x_1 = fluid.data(shape=[None, 1, 4, 5], dtype='float32', name='x_1') - fluid.layers.concat([x_1, x_1], 0) - - input_2 = np.random.random([2, 1, 4, 5]).astype("float32") - input_3 = np.random.random([2, 2, 4, 5]).astype("float32") - x_2 = fluid.data(shape=[2, 1, 4, 5], dtype='float32', name='x_2') - x_3 = fluid.data(shape=[2, 2, 4, 5], dtype='float32', name='x_3') - positive_1_int32 = fluid.layers.fill_constant([1], "float32", 1) - positive_1_int64 = fluid.layers.fill_constant([1], "float32", 1) - out_1 = fluid.layers.concat(input=[x_2, x_3], axis=1) - out_2 = fluid.layers.concat(input=[x_2, x_3], axis=1) - out_3 = fluid.layers.concat(input=[x_2, x_3], axis=1) - - exe = fluid.Executor(place=fluid.XPUPlace(0)) - [res_1, res_2, res_3] = exe.run( - fluid.default_main_program(), - feed={"x_1": input_2, - "x_2": input_2, - "x_3": input_3}, - fetch_list=[out_1, out_2, out_3]) - assert np.array_equal(res_1, np.concatenate((input_2, input_3), axis=1)) - assert np.array_equal(res_2, np.concatenate((input_2, input_3), axis=1)) - assert np.array_equal(res_3, np.concatenate((input_2, input_3), axis=1)) - - def test_errors(self): - with program_guard(Program(), Program()): - # The item in input must be Variable. - x2 = fluid.create_lod_tensor( - np.array([[-1]]), [[1]], fluid.XPUPlace(0)) - x3 = fluid.create_lod_tensor( - np.array([[-1]]), [[1]], fluid.XPUPlace(0)) - self.assertRaises(TypeError, paddle.concat, [x2]) - # The input dtype of concat_op must be float32. - x4 = fluid.data(shape=[4], dtype='uint8', name='x4') - x5 = fluid.data(shape=[4], dtype='uint8', name='x5') - self.assertRaises(TypeError, fluid.layers.concat, [x4, x5]) - - # The type of axis in concat_op should be int or Variable. - x6 = fluid.layers.data(shape=[4], dtype='float16', name='x6') - x7 = fluid.layers.data(shape=[4], dtype='float16', name='x7') - x8 = fluid.layers.data(shape=[4], dtype='float32', name='x8') - - def test_axis_type(): - paddle.concat([x6, x7], 3.2) - - self.assertRaises(TypeError, test_axis_type) - - def test_input_same_dtype(): - paddle.concat([x7, x8]) - - self.assertRaises(TypeError, test_input_same_dtype) - - if __name__ == '__main__': paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_deformable_conv_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_deformable_conv_op_xpu.py new file mode 100644 index 00000000000..5c611b62998 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_deformable_conv_op_xpu.py @@ -0,0 +1,274 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import sys +sys.path.append("..") +import unittest +import numpy as np + +import paddle.fluid.core as core +import paddle.fluid as fluid +from op_test_xpu import OpTest, XPUOpTest +import paddle +from paddle.fluid import Program, program_guard + + +def dmc_bilinear(data_im, height, width, h, w): + h_low = int(np.floor(h)) + w_low = int(np.floor(w)) + h_high = h_low + 1 + w_high = w_low + 1 + + lh = h - h_low + lw = w - w_low + hh = 1 - lh + hw = 1 - lw + + v1 = 0 + if h_low >= 0 and w_low >= 0: + v1 = data_im[h_low, w_low] + v2 = 0 + if h_low >= 0 and w_high <= width - 1: + v2 = data_im[h_low, w_high] + v3 = 0 + if h_high <= height - 1 and w_low >= 0: + v3 = data_im[h_high, w_low] + v4 = 0 + if h_high <= height - 1 and w_high <= width - 1: + v4 = data_im[h_high, w_high] + + w1, w2, w3, w4 = hh * hw, hh * lw, lh * hw, lh * lw + val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4 + + return val + + +def dconv_im2col_gemm(input, offset, mask, filter, group, conv_param): + in_n, in_c, in_h, in_w = input.shape + out_c, f_c, f_h, f_w = filter.shape + + assert offset.shape == (in_n, 2 * f_h * f_w, in_h, in_w) + assert mask.shape == (in_n, f_h * f_w, in_h, in_w) + assert f_c * group == in_c + assert np.mod(out_c, group) == 0 + + stride, pad, dilation = conv_param['stride'], conv_param['pad'],\ + conv_param['dilation'] + out_h = 1 + (in_h + 2 * pad[0] - (dilation[0] * (f_h - 1) + 1)) // stride[0] + out_w = 1 + (in_w + 2 * pad[1] - (dilation[1] * (f_w - 1) + 1)) // stride[1] + assert out_h == in_h + assert out_w == in_w + + col_buffer = np.zeros((in_n, in_c * f_h * f_w, in_h * in_w)) + for n in range(in_n): + for c in range(in_c): + for h in range(out_h): + for w in range(out_w): + for kh in range(f_h): + for kw in range(f_w): + offset_h_table = \ + offset[n, ::2, h, w].reshape(f_h, f_w) + offset_w_table = \ + offset[n, 1::2, h, w].reshape(f_h, f_w) + mask_table = \ + mask[n, :, h, w].reshape(f_h, f_w) + offset_h = offset_h_table[kh, kw] + offset_w = offset_w_table[kh, kw] + val = 0 + im_h = h * stride[0] + kh * dilation[0] \ + + offset_h - pad[0] + im_w = w * stride[0] + kw * dilation[0] \ + + offset_w - pad[1] + if im_h > -1 and im_w > -1 and \ + im_h < in_h and im_w < in_h: + val = dmc_bilinear(input[n, c], in_h, in_w, + im_h, im_w) + val_out = val * mask_table[kh, kw] + col_buffer[n, c * f_h * f_w + kh * f_w + kw, h * + in_w + w] = val_out + + out = np.zeros((in_n, group, int(out_c // group), out_h * out_w)) + weight = filter.reshape(group, int(out_c // group), f_c * f_h * f_w) + col_buffer = col_buffer.reshape( + (in_n, group, int(in_c // group * f_h * f_w), in_h * in_w)) + for n in range(in_n): + for g in range(group): + out[n, g] = np.matmul(weight[g], col_buffer[n, g]) + out = out.reshape(in_n, out_c, out_h, out_w) + return out + + +class TestModulatedDeformableConvOp(XPUOpTest): + def setUp(self): + self.op_type = "deformable_conv" + self.dtype = np.float32 + self.init_group() + self.init_dilation() + self.init_test_case() + + conv_param = { + 'stride': self.stride, + 'pad': self.pad, + 'dilation': self.dilations + } + + input = np.random.random(self.input_size).astype(self.dtype) + offset = 10 * np.random.random(self.offset_size).astype(self.dtype) + mask = 10 * np.random.random(self.mask_size).astype(self.dtype) + filter = np.random.random(self.filter_size).astype(self.dtype) + output = dconv_im2col_gemm(input, offset, mask, filter, self.groups, + conv_param) + output = output.astype(self.dtype) + + self.inputs = { + 'Input': OpTest.np_dtype_to_fluid_dtype(input), + 'Offset': OpTest.np_dtype_to_fluid_dtype(offset), + 'Mask': OpTest.np_dtype_to_fluid_dtype(mask), + 'Filter': OpTest.np_dtype_to_fluid_dtype(filter) + } + self.attrs = { + 'strides': self.stride, + 'paddings': self.pad, + 'groups': self.groups, + 'deformable_groups': self.deformable_groups, + 'im2col_step': self.im2col_step, + 'dilations': self.dilations, + } + self.outputs = {'Output': output} + + def has_cuda(self): + return core.is_compiled_with_cuda() and (self.use_cudnn or + self.use_cuda) + + def test_check_output(self): + if core.is_compiled_with_xpu(): + paddle.enable_static() + place = paddle.XPUPlace(0) + self.check_output_with_place(place) + + def test_check_grad(self): + if core.is_compiled_with_xpu(): + paddle.enable_static() + place = paddle.XPUPlace(0) + self.check_grad_with_place( + place, {'Input', 'Offset', 'Mask', 'Filter'}, + 'Output', + max_relative_error=0.06) + + def init_test_case(self): + self.pad = [1, 1] + self.stride = [1, 1] + self.dilations = [1, 1] + self.input_size = [2, 8, 4, 4] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [8, f_c, 3, 3] + self.im2col_step = 1 + self.deformable_groups = 1 + offset_c = 2 * self.deformable_groups * self.filter_size[ + 2] * self.filter_size[3] + mask_c = self.deformable_groups * self.filter_size[ + 2] * self.filter_size[3] + self.offset_size = [ + self.input_size[0], offset_c, self.input_size[2], self.input_size[3] + ] + self.mask_size = [ + self.input_size[0], mask_c, self.input_size[2], self.input_size[3] + ] + + def init_dilation(self): + self.dilations = [1, 1] + + def init_group(self): + self.groups = 1 + + +class TestWithDilation(TestModulatedDeformableConvOp): + def init_test_case(self): + self.pad = [2, 2] + self.stride = [1, 1] + self.input_size = [4, 3, 4, 4] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3] + self.im2col_step = 1 + self.deformable_groups = 1 + offset_c = 2 * self.deformable_groups * self.filter_size[ + 2] * self.filter_size[3] + mask_c = self.deformable_groups * self.filter_size[ + 2] * self.filter_size[3] + self.offset_size = [ + self.input_size[0], offset_c, self.input_size[2], self.input_size[3] + ] + self.mask_size = [ + self.input_size[0], mask_c, self.input_size[2], self.input_size[3] + ] + + def init_dilation(self): + self.dilations = [2, 2] + + +class TestWith3x3(TestModulatedDeformableConvOp): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3] + self.im2col_step = 1 + self.deformable_groups = 1 + offset_c = 2 * self.deformable_groups * self.filter_size[ + 2] * self.filter_size[3] + mask_c = self.deformable_groups * self.filter_size[ + 2] * self.filter_size[3] + self.offset_size = [ + self.input_size[0], offset_c, self.input_size[2], self.input_size[3] + ] + self.mask_size = [ + self.input_size[0], mask_c, self.input_size[2], self.input_size[3] + ] + + +class TestModulatedDeformableConvInvalidInput(unittest.TestCase): + def test_error(self): + def test_invalid_input(): + paddle.enable_static() + input = [1, 3, 32, 32] + offset = fluid.data( + name='offset', shape=[None, 3, 32, 32], dtype='float32') + mask = fluid.data( + name='mask', shape=[None, 3, 32, 32], dtype='float32') + loss = fluid.layers.deformable_conv( + input, offset, mask, num_filters=4, filter_size=1) + + self.assertRaises(TypeError, test_invalid_input) + + def test_invalid_offset(): + paddle.enable_static() + input = fluid.data( + name='input', shape=[None, 3, 32, 32], dtype='int32') + offset = fluid.data( + name='offset', shape=[None, 3, 32, 32], dtype='float32') + mask = fluid.data( + name='mask', shape=[None, 3, 32, 32], dtype='float32') + loss = fluid.layers.deformable_conv( + input, offset, mask, num_filters=4, filter_size=1) + + self.assertRaises(TypeError, test_invalid_offset) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_reduce_sum_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_reduce_sum_op_xpu.py index 2a0457d1862..638da601a3d 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_reduce_sum_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_reduce_sum_op_xpu.py @@ -18,7 +18,8 @@ import unittest import numpy as np import sys sys.path.append("..") -from op_test import OpTest, skip_check_grad_ci +from op_test_xpu import OpTest, XPUOpTest +from op_test import skip_check_grad_ci import paddle import paddle.fluid.core as core import paddle.fluid as fluid @@ -26,180 +27,128 @@ from paddle.fluid import compiler, Program, program_guard from paddle.fluid.framework import convert_np_dtype_to_dtype_ -class TestSumOp(OpTest): +class TestXPUReduceSumOp(XPUOpTest): def setUp(self): - self.op_type = "reduce_sum" - self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")} - self.attrs = {'use_xpu': True} - self.outputs = {'Out': self.inputs['X'].sum(axis=0)} - - def test_check_output(self): - if paddle.is_compiled_with_xpu(): - place = paddle.XPUPlace(0) - self.check_output_with_place(place) - - def check_grad_(self): - self.check_grad(['X'], 'Out') - - -class TestSumOp5D(OpTest): - def setUp(self): - self.op_type = "reduce_sum" - self.inputs = { - 'X': np.random.random((1, 2, 5, 6, 10)).astype("float64") + self.init_op_type() + self.initTestCase() + self.use_xpu = True + self.use_mkldnn = False + self.attrs = { + 'dim': self.axis, + 'keep_dim': self.keep_dim, + 'reduce_all': self.reduce_all } - self.attrs = {'use_xpu': True} - self.outputs = {'Out': self.inputs['X'].sum(axis=0)} + self.inputs = {'X': np.random.random(self.shape).astype("float32")} + if self.attrs['reduce_all']: + self.outputs = {'Out': self.inputs['X'].sum()} + else: + self.outputs = { + 'Out': self.inputs['X'].sum(axis=self.axis, + keepdims=self.attrs['keep_dim']) + } def test_check_output(self): if paddle.is_compiled_with_xpu(): + paddle.enable_static() place = paddle.XPUPlace(0) self.check_output_with_place(place) def test_check_grad(self): - self.check_grad(['X'], 'Out') - - -class TestSumOp6D(OpTest): - def setUp(self): - self.op_type = "reduce_sum" - self.inputs = { - 'X': np.random.random((1, 1, 2, 5, 6, 10)).astype("float64") - } - self.attrs = {'use_xpu': True} - self.outputs = {'Out': self.inputs['X'].sum(axis=0)} - - def test_check_output(self): if paddle.is_compiled_with_xpu(): + paddle.enable_static() place = paddle.XPUPlace(0) - self.check_output_with_place(place) + self.check_grad_with_place(place, ['X'], 'Out') - def test_check_grad(self): - self.check_grad(['X'], 'Out') + def init_op_type(self): + self.op_type = "reduce_sum" + self.use_mkldnn = False + self.keep_dim = False + self.reduce_all = False + def initTestCase(self): + self.shape = (5, 6, 10) + self.axis = (0, ) -class TestSumOp8D(OpTest): - def setUp(self): - self.op_type = "reduce_sum" - self.inputs = { - 'X': np.random.random((1, 3, 1, 2, 1, 4, 3, 10)).astype("float64") - } - self.attrs = {'dim': (0, 3), 'use_xpu': True} - self.outputs = {'Out': self.inputs['X'].sum(axis=(0, 3))} - def test_check_output(self): - if paddle.is_compiled_with_xpu(): - place = paddle.XPUPlace(0) - self.check_output_with_place(place) +class TestSumOp5D(TestXPUReduceSumOp): + def initTestCase(self): + self.shape = (1, 2, 5, 6, 10) + self.axis = (0, ) - def test_check_grad(self): - self.check_grad(['X'], 'Out') +class TestSumOp6D(TestXPUReduceSumOp): + def initTestCase(self): + self.shape = (1, 1, 2, 5, 6, 10) + self.axis = (0, ) -class Test1DReduce(OpTest): - def setUp(self): - self.op_type = "reduce_sum" - self.inputs = {'X': np.random.random(120).astype("float64")} - self.attrs = {'use_xpu': True} - self.outputs = {'Out': self.inputs['X'].sum(axis=0)} - def test_check_output(self): - if paddle.is_compiled_with_xpu(): - place = paddle.XPUPlace(0) - self.check_output_with_place(place) +class TestSumOp8D(TestXPUReduceSumOp): + def initTestCase(self): + self.shape = (1, 3, 1, 2, 1, 4, 3, 10) + self.axis = (0, 3) - def test_check_grad(self): - self.check_grad(['X'], 'Out') +class Test1DReduce(TestXPUReduceSumOp): + def initTestCase(self): + self.shape = 120 + self.axis = (0, ) -class Test2DReduce0(Test1DReduce): - def setUp(self): - self.op_type = "reduce_sum" - self.attrs = {'dim': [0], 'use_xpu': True} - self.inputs = {'X': np.random.random((20, 10)).astype("float64")} - self.outputs = {'Out': self.inputs['X'].sum(axis=0)} +class Test2DReduce0(TestXPUReduceSumOp): + def initTestCase(self): + self.shape = (20, 10) + self.axis = (0, ) -class Test2DReduce1(Test1DReduce): - def setUp(self): - self.op_type = "reduce_sum" - self.attrs = {'dim': [1], 'use_xpu': True} - self.inputs = {'X': np.random.random((20, 10)).astype("float64")} - self.outputs = { - 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim'])) - } +class Test2DReduce1(TestXPUReduceSumOp): + def initTestCase(self): + self.shape = (20, 10) + self.axis = (1, ) -class Test3DReduce0(Test1DReduce): - def setUp(self): - self.op_type = "reduce_sum" - self.attrs = {'dim': [1], 'use_xpu': True} - self.inputs = {'X': np.random.random((5, 6, 7)).astype("float64")} - self.outputs = { - 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim'])) - } +class Test3DReduce0(TestXPUReduceSumOp): + def initTestCase(self): + self.shape = (5, 6, 7) + self.axis = (1, ) -class Test3DReduce1(Test1DReduce): - def setUp(self): - self.op_type = "reduce_sum" - self.attrs = {'dim': [2], 'use_xpu': True} - self.inputs = {'X': np.random.random((5, 6, 7)).astype("float64")} - self.outputs = { - 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim'])) - } +class Test3DReduce1(TestXPUReduceSumOp): + def initTestCase(self): + self.shape = (5, 6, 7) + self.axis = (2, ) -class Test3DReduce2(Test1DReduce): - def setUp(self): - self.op_type = "reduce_sum" - self.attrs = {'dim': [-2], 'use_xpu': True} - self.inputs = {'X': np.random.random((5, 6, 7)).astype("float64")} - self.outputs = { - 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim'])) - } +class Test3DReduce2(TestXPUReduceSumOp): + def initTestCase(self): + self.shape = (5, 6, 7) + self.axis = (-2, ) -class Test3DReduce3(Test1DReduce): - def setUp(self): - self.op_type = "reduce_sum" - self.attrs = {'dim': [1, 2], 'use_xpu': True} - self.inputs = {'X': np.random.random((5, 6, 7)).astype("float64")} - self.outputs = { - 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim'])) - } +class Test3DReduce3(TestXPUReduceSumOp): + def initTestCase(self): + self.shape = (5, 6, 7) + self.axis = (1, 2) -class TestKeepDimReduce(Test1DReduce): - def setUp(self): - self.op_type = "reduce_sum" - self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")} - self.attrs = {'dim': [1], 'keep_dim': True, 'use_xpu': True} - self.outputs = { - 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']), - keepdims=self.attrs['keep_dim']) - } +class TestKeepDimReduce(TestXPUReduceSumOp): + def initTestCase(self): + self.shape = (5, 6, 10) + self.axis = (1, ) + self.keep_dim = True -class TestKeepDim8DReduce(Test1DReduce): - def setUp(self): - self.op_type = "reduce_sum" - self.inputs = { - 'X': np.random.random((2, 5, 3, 2, 2, 3, 4, 2)).astype("float64") - } - self.attrs = {'dim': (3, 4, 5), 'keep_dim': True, 'use_xpu': True} - self.outputs = { - 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']), - keepdims=self.attrs['keep_dim']) - } +class TestKeepDim8DReduce(TestXPUReduceSumOp): + def initTestCase(self): + self.shape = (2, 5, 3, 2, 2, 3, 4, 2) + self.axis = (3, 4, 5) + self.keep_dim = True -class TestReduceAll(Test1DReduce): - def setUp(self): - self.op_type = "reduce_sum" - self.inputs = {'X': np.random.random((5, 6, 2, 10)).astype("float64")} - self.attrs = {'reduce_all': True, 'use_xpu': True} - self.outputs = {'Out': self.inputs['X'].sum()} + +class TestReduceAll(TestXPUReduceSumOp): + def initTestCase(self): + self.shape = (5, 6, 2, 10) + self.axis = (0, ) + self.reduce_all = True if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/xpu/test_softmax_with_cross_entropy_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_softmax_with_cross_entropy_op_xpu.py index 5a8985315ea..454ef0db052 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_softmax_with_cross_entropy_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_softmax_with_cross_entropy_op_xpu.py @@ -13,15 +13,16 @@ # limitations under the License. from __future__ import print_function +import sys +sys.path.append("..") + from test_softmax_op import stable_softmax -from op_test import OpTest +from op_test_xpu import XPUOpTest import paddle.fluid.core as core import paddle import unittest import numpy as np -import sys -sys.path.append("..") def cross_entropy(softmax, label, soft_label, axis, ignore_index=-1): @@ -44,7 +45,7 @@ def cross_entropy(softmax, label, soft_label, axis, ignore_index=-1): return result.reshape(label.shape) -class TestSoftmaxWithCrossEntropyOp(OpTest): +class TestSoftmaxWithCrossEntropyOp(XPUOpTest): """ Test softmax with cross entropy operator with discreate one-hot labels. """ diff --git a/python/paddle/fluid/tests/unittests/xpu/test_transpose_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_transpose_op_xpu.py index c191e5f0b29..41df4481e2d 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_transpose_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_transpose_op_xpu.py @@ -19,24 +19,27 @@ import numpy as np import sys sys.path.append("..") -from op_test import OpTest +from op_test_xpu import OpTest, XPUOpTest import paddle +import paddle.fluid.core as core import paddle.fluid as fluid from paddle.fluid import compiler, Program, program_guard -class TestXPUTransposeOp(OpTest): +class TestXPUTransposeOp(XPUOpTest): def setUp(self): self.init_op_type() self.initTestCase() - self.inputs = {'X': np.random.random(self.shape).astype("float64")} + self.use_xpu = True + self.use_mkldnn = False + self.inputs = {'X': np.random.random(self.shape).astype("float32")} self.attrs = { 'axis': list(self.axis), 'use_mkldnn': False, 'use_xpu': True } self.outputs = { - 'XShape': np.random.random(self.shape).astype("float64"), + 'XShape': np.random.random(self.shape).astype("float32"), 'Out': self.inputs['X'].transpose(self.axis) } @@ -121,110 +124,5 @@ class TestCase9(TestXPUTransposeOp): self.axis = (6, 1, 3, 5, 0, 2, 4, 7) -class TestTransposeOpError(unittest.TestCase): - def test_errors(self): - with program_guard(Program(), Program()): - x = fluid.layers.data(name='x', shape=[10, 5, 3], dtype='float64') - - def test_x_Variable_check(): - # the Input(x)'s type must be Variable - fluid.layers.transpose("not_variable", perm=[1, 0, 2]) - - self.assertRaises(TypeError, test_x_Variable_check) - - def test_x_dtype_check(): - # the Input(x)'s dtype must be one of [float16, float32, float64, int32, int64] - x1 = fluid.layers.data( - name='x1', shape=[10, 5, 3], dtype='bool') - fluid.layers.transpose(x1, perm=[1, 0, 2]) - - self.assertRaises(TypeError, test_x_dtype_check) - - def test_perm_list_check(): - # Input(perm)'s type must be list - fluid.layers.transpose(x, perm="[1, 0, 2]") - - self.assertRaises(TypeError, test_perm_list_check) - - def test_perm_length_and_x_dim_check(): - # Input(perm) is the permutation of dimensions of Input(input) - # its length should be equal to dimensions of Input(input) - fluid.layers.transpose(x, perm=[1, 0, 2, 3, 4]) - - self.assertRaises(ValueError, test_perm_length_and_x_dim_check) - - def test_each_elem_value_check(): - # Each element in Input(perm) should be less than Input(x)'s dimension - fluid.layers.transpose(x, perm=[3, 5, 7]) - - self.assertRaises(ValueError, test_each_elem_value_check) - - -class TestTAPI(unittest.TestCase): - def test_out(self): - with fluid.program_guard(fluid.Program()): - data = fluid.data(shape=[10], dtype="float64", name="data") - data_t = paddle.t(data) - place = fluid.CPUPlace() - exe = fluid.Executor(place) - data_np = np.random.random([10]).astype("float64") - result, = exe.run(feed={"data": data_np}, fetch_list=[data_t]) - expected_result = np.transpose(data_np) - self.assertEqual((result == expected_result).all(), True) - - with fluid.program_guard(fluid.Program()): - data = fluid.data(shape=[10, 5], dtype="float64", name="data") - data_t = paddle.t(data) - place = fluid.CPUPlace() - exe = fluid.Executor(place) - data_np = np.random.random([10, 5]).astype("float64") - result, = exe.run(feed={"data": data_np}, fetch_list=[data_t]) - expected_result = np.transpose(data_np) - self.assertEqual((result == expected_result).all(), True) - - with fluid.program_guard(fluid.Program()): - data = fluid.data(shape=[1, 5], dtype="float64", name="data") - data_t = paddle.t(data) - place = fluid.CPUPlace() - exe = fluid.Executor(place) - data_np = np.random.random([1, 5]).astype("float64") - result, = exe.run(feed={"data": data_np}, fetch_list=[data_t]) - expected_result = np.transpose(data_np) - self.assertEqual((result == expected_result).all(), True) - - with fluid.dygraph.guard(): - np_x = np.random.random([10]).astype("float64") - data = fluid.dygraph.to_variable(np_x) - z = paddle.t(data) - np_z = z.numpy() - z_expected = np.array(np.transpose(np_x)) - self.assertEqual((np_z == z_expected).all(), True) - - with fluid.dygraph.guard(): - np_x = np.random.random([10, 5]).astype("float64") - data = fluid.dygraph.to_variable(np_x) - z = paddle.t(data) - np_z = z.numpy() - z_expected = np.array(np.transpose(np_x)) - self.assertEqual((np_z == z_expected).all(), True) - - with fluid.dygraph.guard(): - np_x = np.random.random([1, 5]).astype("float64") - data = fluid.dygraph.to_variable(np_x) - z = paddle.t(data) - np_z = z.numpy() - z_expected = np.array(np.transpose(np_x)) - self.assertEqual((np_z == z_expected).all(), True) - - def test_errors(self): - with fluid.program_guard(fluid.Program()): - x = fluid.data(name='x', shape=[10, 5, 3], dtype='float64') - - def test_x_dimension_check(): - paddle.t(x) - - self.assertRaises(ValueError, test_x_dimension_check) - - if __name__ == "__main__": unittest.main() -- GitLab