From 64f29fbb705228c7d5c921ebe9a5207926337d09 Mon Sep 17 00:00:00 2001
From: QingshuChen
Date: Tue, 1 Dec 2020 13:49:36 +0800
Subject: [PATCH] update kunlun conv2d/softmax/elementwise implementation
 (#29229)

* update conv2d & softmax to new xpu api
* test=kunlun
* remove useless comments
* test=kunlun
* remove softmax xpu op
* test=kunlun
* update kunlun softmax
* test=kunlun
* update xpu unittest
* test=kunlun
* fix elementwise_grad bug for kunlun
* test=kunlun
---
 cmake/external/xpu.cmake                      |   2 +-
 paddle/fluid/operators/conv_op_xpu.cc         | 125 ++++--------
 .../operators/elementwise/elementwise_xpu.h   |  52 ++++----
 paddle/fluid/operators/softmax_op_xpu.cc      |  51 ++++---
 paddle/fluid/platform/xpu_header.h            |   8 ++
 python/paddle/fluid/io.py                     |   4 +
 .../fluid/tests/unittests/op_test_xpu.py      |  11 --
 .../test_softmax_with_cross_entropy_op_xpu.py |   5 +-
 8 files changed, 95 insertions(+), 163 deletions(-)

diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake
index 8d3fee915c4..ff8a3b9838a 100644
--- a/cmake/external/xpu.cmake
+++ b/cmake/external/xpu.cmake
@@ -4,7 +4,7 @@ endif()
 
 INCLUDE(ExternalProject)
 SET(XPU_PROJECT "extern_xpu")
-SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2020_11_10.tar.gz" CACHE STRING "" FORCE)
+SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2020_11_30.tar.gz" CACHE STRING "" FORCE)
 SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu")
 SET(XPU_DOWNLOAD_DIR "${XPU_SOURCE_DIR}/src/${XPU_PROJECT}")
 SET(XPU_INSTALL_DIR "${THIRD_PARTY_PATH}/install/xpu")
diff --git a/paddle/fluid/operators/conv_op_xpu.cc b/paddle/fluid/operators/conv_op_xpu.cc
index 65ed34e8a5e..46af4d30500 100644
--- a/paddle/fluid/operators/conv_op_xpu.cc
+++ b/paddle/fluid/operators/conv_op_xpu.cc
@@ -27,10 +27,6 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
     // that avoids modifying the variable in the Scope.
     Tensor filter = *context.Input<Tensor>("Filter");
     Tensor* output = context.Output<Tensor>("Output");
-    // Tensor* max_input = context.Output<Tensor>("MaxInput");
-    // Tensor* max_filter = context.Output<Tensor>("MaxFilter");
-    // max_input->mutable_data<T>(context.GetPlace());
-    // max_filter->mutable_data<T>(context.GetPlace());
     output->mutable_data<T>(context.GetPlace());
     int groups = context.Attr<int>("groups");
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
@@ -43,52 +39,18 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
     const int f = static_cast<int>(filter.dims()[0]);
     const int win_h = static_cast<int>(filter.dims()[2]);
     const int win_w = static_cast<int>(filter.dims()[3]);
-    PADDLE_ENFORCE_EQ(
-        dilations[0] == 1 && dilations[1] == 1, true,
-        platform::errors::InvalidArgument("XPU only support dilation == 1."));
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    // PADDLE_ENFORCE_EQ(
-    //     xpu::findmax(dev_ctx.x_context(), input->data<float>(),
-    //                  input->numel(),
-    //                  max_input->data<float>()) == xpu::Error_t::SUCCESS,
-    //     true, platform::errors::InvalidArgument(
-    //               "XPU conv kernel error,can not finde max_input,please "
-    //               "check whether Baidu Kunlun "
-    //               "Card is properly installed."));
-    // PADDLE_ENFORCE_EQ(
-    //     xpu::findmax(dev_ctx.x_context(), filter.data<float>(),
-    //                  filter.numel(),
-    //                  max_filter->data<float>()) == xpu::Error_t::SUCCESS,
-    //     true, platform::errors::InvalidArgument(
-    //               "XPU conv kernel error,can not find max_filter,please "
-    //               "check whether Baidu Kunlun "
-    //               "Card is properly installed."));
-    if (groups == 1) {
-      int r = xpu::conv2d_forward_int16(
-          dev_ctx.x_context(), batch_size, img_c, img_h, img_w, f, win_h,
-          win_w, strides[0], strides[1], paddings[0], paddings[1],
-          dilations[0], dilations[1], groups, input->data<float>(),
-          filter.data<float>(), output->data<float>(), nullptr, nullptr,
-          xpu::Activation_t::LINEAR, nullptr, nullptr);
-      // max_input->data<float>(), max_filter->data<float>());
-      PADDLE_ENFORCE_EQ(
-          r, XPU_SUCCESS,
-          platform::errors::External("XPU conv kernel return wrong value[%d], "
-                                     "please check whether Baidu Kunlun Card "
-                                     "is properly installed.",
-                                     r));
-    } else {
-      int r = xpu::conv2d_int16_with_group(
-          dev_ctx.x_context(), input->data<float>(), filter.data<float>(),
-          output->data<float>(), batch_size, img_c, img_h, img_w, f, win_h,
-          win_w, groups, strides[0], strides[1], paddings[0], paddings[1],
-          nullptr, nullptr);
-      // max_input->data<float>(), max_filter->data<float>());
-      PADDLE_ENFORCE_EQ(
-          r, XPU_SUCCESS,
-          platform::errors::External("XPU conv kernel return wrong value[%d], "
-                                     "please check whether Baidu Kunlun Card "
-                                     "is properly installed.",
-                                     r));
-    }
+    std::vector<int> k_size;
+    k_size.push_back(win_h);
+    k_size.push_back(win_w);
+    int r = xpu::conv2d(
+        dev_ctx.x_context(), input->data<float>(), filter.data<float>(),
+        output->data<float>(), batch_size, img_c, img_h, img_w, f, k_size,
+        strides, paddings, dilations, groups, nullptr, nullptr, nullptr, true);
+    PADDLE_ENFORCE_EQ(
+        r, XPU_SUCCESS,
+        platform::errors::External("XPU conv kernel return wrong value[%d %s]",
+                                   r, XPUAPIErrorMsg[r]));
   }
 };
 template <typename DeviceContext, typename T>
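
For reference, the new xpu::conv2d packs the kernel extents into k_size and
takes strides/paddings/dilations as whole vectors; the output spatial size
those attributes imply follows the standard convolution arithmetic. A minimal
self-contained sketch (the helper name is illustrative, not part of the
patch):

    // Output extent along one spatial dimension of a convolution:
    //   out = (in + 2 * pad - dilation * (k - 1) - 1) / stride + 1
    static int ConvOutSize(int in, int k, int stride, int pad, int dilation) {
      return (in + 2 * pad - dilation * (k - 1) - 1) / stride + 1;
    }
    // e.g. img_h = 32, win_h = 3, strides[0] = 1, paddings[0] = 1,
    // dilations[0] = 1  ->  ConvOutSize(32, 3, 1, 1, 1) == 32
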
@@ -96,9 +58,6 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     const Tensor* input = context.Input<Tensor>("Input");
-    // const Tensor* max_input = context.Input<Tensor>("MaxInput");
-    // const Tensor* max_filter = context.Input<Tensor>("MaxFilter");
-    // Tensor* max_output_grad = context.Output<Tensor>("MaxOutputGrad");
     const Tensor* output_grad =
         context.Input<Tensor>(framework::GradVarName("Output"));
     Tensor* input_grad =
@@ -115,11 +74,6 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
     std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
     const int batch_size = static_cast<int>(input->dims()[0]);
-    PADDLE_ENFORCE_EQ(groups == 1, true,
-                      platform::errors::InvalidArgument(
-                          "XPU only support groups == 1."));
-    PADDLE_ENFORCE_EQ(
-        dilations[0] == 1 && dilations[1] == 1, true,
-        platform::errors::InvalidArgument("XPU only support dilation == 1."));
     const int img_c = static_cast<int>(input->dims()[1]);
     const int img_h = static_cast<int>(input->dims()[2]);
     const int img_w = static_cast<int>(input->dims()[3]);
@@ -133,52 +87,24 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
       filter_grad->mutable_data<T>(context.GetPlace());
     }
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    // max_output_grad->Resize({4});
-    // max_output_grad->mutable_data<T>(context.GetPlace());
-    // PADDLE_ENFORCE_EQ(
-    //     xpu::findmax(dev_ctx.x_context(), output_grad->data<float>(),
-    //                  output_grad->numel(),
-    //                  max_output_grad->data<float>()) == xpu::Error_t::SUCCESS,
-    //     true,
-    //     platform::errors::External(
-    //         "XPU conv kernel error, can not find max_output_grad, please check "
-    //         "whether Baidu Kunlun Card is "
-    //         "properly installed."));
-    if (input_grad) {
-      int r = xpu::conv2d_backward_int16(
-          dev_ctx.x_context(), batch_size, img_c, img_h, img_w, f, win_h,
-          win_w, strides[0], strides[1], paddings[0], paddings[1],
-          dilations[0], dilations[1], groups, output_grad->data<float>(),
-          filter.data<float>(), input_grad->data<float>(), nullptr, nullptr);
-      // max_output_grad->data<float>(), max_filter->data<float>());
-      PADDLE_ENFORCE_EQ(
-          r, XPU_SUCCESS,
-          platform::errors::External("XPU conv kernel return wrong value[%d], "
-                                     "please check whether Baidu Kunlun Card "
-                                     "is properly installed.",
-                                     r));
-    }
-    if (filter_grad) {
-      int r = xpu::conv2d_backward_weight_int16(
-          dev_ctx.x_context(), batch_size, img_c, img_h, img_w, f, win_h,
-          win_w, strides[0], strides[1], paddings[0], paddings[1],
-          dilations[0], dilations[1], groups, output_grad->data<float>(),
-          input->data<float>(), filter_grad->data<float>(), nullptr, nullptr);
-      // max_output_grad->data<float>(), max_input->data<float>());
-      PADDLE_ENFORCE_EQ(
-          r, XPU_SUCCESS,
-          platform::errors::External("XPU conv kernel return wrong value[%d], "
-                                     "please check whether Baidu Kunlun Card "
-                                     "is properly installed.",
-                                     r));
-    }
+    std::vector<int> k_size;
+    k_size.push_back(win_h);
+    k_size.push_back(win_w);
+    int r = xpu::conv2d_grad(
+        dev_ctx.x_context(), input->data<float>(), filter.data<float>(),
+        output_grad->data<float>(),
+        input_grad ? input_grad->data<float>() : nullptr,
+        filter_grad ? filter_grad->data<float>() : nullptr, batch_size, img_c,
+        img_h, img_w, f, k_size, strides, paddings, dilations, groups, nullptr,
+        nullptr, nullptr, nullptr, nullptr, true);
+    PADDLE_ENFORCE_EQ(
+        r, XPU_SUCCESS,
+        platform::errors::External("XPU conv kernel return wrong value[%d %s]",
+                                   r, XPUAPIErrorMsg[r]));
   }
 };
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-// TODO(xingzhaolong): neon kernel for mobile
 REGISTER_OP_XPU_KERNEL(
     depthwise_conv2d,
     ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext, float>);
@@ -187,4 +113,7 @@ REGISTER_OP_XPU_KERNEL(
 REGISTER_OP_XPU_KERNEL(
     conv2d_grad,
     ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
+REGISTER_OP_XPU_KERNEL(
+    depthwise_conv2d_grad,
+    ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
 #endif
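
Note the shape of the new backward API: one xpu::conv2d_grad call replaces the
old conv2d_backward_int16 / conv2d_backward_weight_int16 pair, and a kernel
that needs only one of the two gradients passes nullptr for the other. A toy
sketch of that nullable-output convention (names and the gradient math are
illustrative only, not the actual conv math):

    #include <cstddef>

    // y = w * x elementwise, so dx = dout * w and dw = dout * x; each
    // gradient is produced only when the caller supplied a buffer for it.
    static void MulGrad(const float* x, const float* w, const float* dout,
                        std::size_t n, float* dx_or_null, float* dw_or_null) {
      for (std::size_t i = 0; i < n; ++i) {
        if (dx_or_null) dx_or_null[i] = dout[i] * w[i];
        if (dw_or_null) dw_or_null[i] = dout[i] * x[i];
      }
    }
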
diff --git a/paddle/fluid/operators/elementwise/elementwise_xpu.h b/paddle/fluid/operators/elementwise/elementwise_xpu.h
index fdf5aeeba53..89d8487fdbb 100644
--- a/paddle/fluid/operators/elementwise/elementwise_xpu.h
+++ b/paddle/fluid/operators/elementwise/elementwise_xpu.h
@@ -65,7 +65,7 @@ static std::pair<std::vector<int>, std::vector<int>> XPUReducesAxisVector(
   }
   int yidx = 0;
   for (size_t i = 0; i < x_vector.size(); ++i) {
-    if (y[yidx] == 1) {
+    if (yidx >= y.size() || y[yidx] == 1) {
       axis_v.push_back(i);
       yidx++;
       continue;
@@ -134,10 +134,10 @@ void XPUElementwise(
     std::pair<std::vector<int>, std::vector<int>> bcast_v =
         XPUDimsToBroadcastVector(framework::make_ddim(x_dims_array), out_dim);
 
-    ret = xpu::broadcast<T>(
-        dev_ctx.x_context(), x_data,
-        x_broadcast_tensor.mutable_data<T>(ctx.GetPlace(), z->numel()),
-        bcast_v.first, bcast_v.second);
+    ret = xpu::broadcast<T>(dev_ctx.x_context(), x_data,
+                            x_broadcast_tensor.mutable_data<T>(
+                                ctx.GetPlace(), z->numel() * sizeof(T)),
+                            bcast_v.first, bcast_v.second);
     PADDLE_ENFORCE_EQ(
         ret, xpu::SUCCESS,
         platform::errors::External(
@@ -153,10 +153,10 @@ void XPUElementwise(
     std::vector<int> bcast_y_v;
     std::pair<std::vector<int>, std::vector<int>> bcast_v =
         XPUDimsToBroadcastVector(framework::make_ddim(y_dims_array), out_dim);
-    ret = xpu::broadcast<T>(
-        dev_ctx.x_context(), y_data,
-        y_broadcast_tensor.mutable_data<T>(ctx.GetPlace(), z->numel()),
-        bcast_v.first, bcast_v.second);
+    ret = xpu::broadcast<T>(dev_ctx.x_context(), y_data,
+                            y_broadcast_tensor.mutable_data<T>(
+                                ctx.GetPlace(), z->numel() * sizeof(T)),
+                            bcast_v.first, bcast_v.second);
     PADDLE_ENFORCE_EQ(
         ret, xpu::SUCCESS,
         platform::errors::External(
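
The "* sizeof(T)" changes here and below are the elementwise_grad fix named in
the commit message: the requested-size argument of Tensor::mutable_data counts
bytes, while the old code passed the element count, under-allocating the
scratch buffer by a factor of sizeof(T). A minimal sketch of the distinction,
assuming a byte-sized allocator as in Paddle:

    #include <cstdlib>

    // The underlying allocation, like mutable_data(place, requested_size),
    // is sized in bytes; passing num_elements alone would reserve only
    // num_elements bytes, 4x too small for float.
    template <typename T>
    T* AllocElements(std::size_t num_elements) {
      return static_cast<T*>(std::malloc(num_elements * sizeof(T)));
    }
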
@@ -231,13 +231,15 @@ void XPUElementwiseGrad(const framework::ExecutionContext& ctx,
   bool dx_need_reduce = (dx != nullptr) && (dx->numel() != len);
   bool dy_need_reduce = (dy != nullptr) && (dy->numel() != len);
 
-  T* dx_data = ((dx == nullptr) || dx_need_reduce)
-                   ? (dx_local_tensor.mutable_data<T>(ctx.GetPlace(), len))
-                   : (dx->mutable_data<T>(ctx.GetPlace()));
+  T* dx_data =
+      ((dx == nullptr) || dx_need_reduce)
+          ? (dx_local_tensor.mutable_data<T>(ctx.GetPlace(), len * sizeof(T)))
+          : (dx->mutable_data<T>(ctx.GetPlace()));
 
-  T* dy_data = ((dy == nullptr) || dy_need_reduce)
-                   ? (dy_local_tensor.mutable_data<T>(ctx.GetPlace(), len))
-                   : (dy->mutable_data<T>(ctx.GetPlace()));
+  T* dy_data =
+      ((dy == nullptr) || dy_need_reduce)
+          ? (dy_local_tensor.mutable_data<T>(ctx.GetPlace(), len * sizeof(T)))
+          : (dy->mutable_data<T>(ctx.GetPlace()));
 
   int ret = xpu::SUCCESS;
   auto& dev_ctx =
@@ -250,8 +252,8 @@ void XPUElementwiseGrad(const framework::ExecutionContext& ctx,
         XPUDimsToBroadcastVector(framework::make_ddim(x_dims_array), out_dim);
     ret = xpu::broadcast<T>(
         dev_ctx.x_context(), x_data,
-        x_broadcast_tensor.mutable_data<T>(ctx.GetPlace(), len), bcast_v.first,
-        bcast_v.second);
+        x_broadcast_tensor.mutable_data<T>(ctx.GetPlace(), len * sizeof(T)),
+        bcast_v.first, bcast_v.second);
     PADDLE_ENFORCE_EQ(ret, xpu::SUCCESS,
                       platform::errors::External(
                           "XPU kernel broadcast error occur! %d", ret));
@@ -267,8 +269,8 @@ void XPUElementwiseGrad(const framework::ExecutionContext& ctx,
         XPUDimsToBroadcastVector(framework::make_ddim(y_dims_array), out_dim);
     ret = xpu::broadcast<T>(
         dev_ctx.x_context(), y_data,
-        y_broadcast_tensor.mutable_data<T>(ctx.GetPlace(), len), bcast_v.first,
-        bcast_v.second);
+        y_broadcast_tensor.mutable_data<T>(ctx.GetPlace(), len * sizeof(T)),
+        bcast_v.first, bcast_v.second);
     PADDLE_ENFORCE_EQ(ret, xpu::SUCCESS,
                       platform::errors::External(
                           "XPU kernel broadcast error occur! %d", ret));
@@ -287,9 +289,9 @@ void XPUElementwiseGrad(const framework::ExecutionContext& ctx,
     const framework::DDim& dx_dims = dx->dims();
     std::pair<std::vector<int>, std::vector<int>> reduce_v =
         XPUReducesAxisVector(out_dim, dx_dims);
-    ret = xpu::reduce_sum(dev_ctx.x_context(), dx_data,
-                          dx->mutable_data<T>(ctx.GetPlace()), reduce_v.first,
-                          reduce_v.second);
+    ret = xpu::reduce_sum<T>(dev_ctx.x_context(), dx_data,
+                             dx->mutable_data<T>(ctx.GetPlace()),
+                             reduce_v.first, reduce_v.second);
     PADDLE_ENFORCE_EQ(
         ret, xpu::SUCCESS,
         platform::errors::External("XPU kernel reduce_sum occur error in "
@@ -302,9 +304,9 @@ void XPUElementwiseGrad(const framework::ExecutionContext& ctx,
     const framework::DDim& dy_dims = dy->dims();
     std::pair<std::vector<int>, std::vector<int>> reduce_v =
         XPUReducesAxisVector(out_dim, dy_dims);
-    ret = xpu::reduce_sum(dev_ctx.x_context(), dy_data,
-                          dy->mutable_data<T>(ctx.GetPlace()), reduce_v.first,
-                          reduce_v.second);
+    ret = xpu::reduce_sum<T>(dev_ctx.x_context(), dy_data,
+                             dy->mutable_data<T>(ctx.GetPlace()),
+                             reduce_v.first, reduce_v.second);
     PADDLE_ENFORCE_EQ(
         ret, xpu::SUCCESS,
         platform::errors::External("XPU kernel reduce_sum occur error in "
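
A gradient that was broadcast up to the output shape must be summed back down
to its own shape, which is what the reduce_sum calls above do; the reduce axes
come from XPUReducesAxisVector, whose new "yidx >= y.size()" guard also treats
output dimensions beyond the end of the smaller shape as reduce axes. A
standalone sketch of that axis selection (simplified: it assumes the two
shapes are already left-aligned):

    #include <cstddef>
    #include <vector>

    // Collect the axes of out_shape that must be sum-reduced so the result
    // matches small_shape: dims where small_shape is 1 or has run out.
    static std::vector<int> ReduceAxes(const std::vector<int>& out_shape,
                                       const std::vector<int>& small_shape) {
      std::vector<int> axes;
      for (std::size_t i = 0; i < out_shape.size(); ++i) {
        if (i >= small_shape.size() || small_shape[i] == 1) {
          axes.push_back(static_cast<int>(i));  // mirrors the patched guard
        }
      }
      return axes;
    }
    // e.g. out_shape = {4, 3, 2}, small_shape = {4, 3} -> axes = {2}

diff --git a/paddle/fluid/operators/softmax_op_xpu.cc b/paddle/fluid/operators/softmax_op_xpu.cc
index 29740000aeb..312c5d2dde1 100644
--- a/paddle/fluid/operators/softmax_op_xpu.cc
+++ b/paddle/fluid/operators/softmax_op_xpu.cc
@@ -1,11 +1,8 @@
 /* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
-
 http://www.apache.org/licenses/LICENSE-2.0
-
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.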
@@ -30,29 +27,27 @@ class SoftmaxXPUKernel : public framework::OpKernel<T> {
     auto* x = context.Input<Tensor>("X");
     auto* out = context.Output<Tensor>("Out");
     const int rank = x->dims().size();
-    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
-    PADDLE_ENFORCE_EQ(axis == -1 || axis == rank - 1, true,
-                      platform::errors::InvalidArgument(
-                          "xpu softmax kernel only support last dimension of x "
-                          "(axis==-1 or axis==x_dims-1), but received axis: "
-                          "%d, x's shape: %s.",
-                          axis, x->dims()));
+    int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
 
     // allocate memory on device.
     out->mutable_data<T>(context.GetPlace());
 
-    const int n = SizeToAxis(axis, x->dims());
-    const int d = SizeFromAxis(axis, x->dims());
+    std::vector<int> x_dims;
+    for (int i = 0; i < rank; i++) {
+      x_dims.push_back(x->dims()[i]);
+    }
+    if (axis < 0) {
+      axis += rank;
+    }
 
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    int r = xpu::softmax2d_forward(dev_ctx.x_context(), x->data<float>(),
-                                   out->data<float>(), n, d, d <= 2048);
+    int r = xpu::softmax<float>(dev_ctx.x_context(), x->data<float>(),
+                                out->data<float>(), x_dims, axis);
     PADDLE_ENFORCE_EQ(
         r, XPU_SUCCESS,
         platform::errors::External("XPU API(softmax2d_forward) return wrong "
-                                   "value[%d], please check whether "
-                                   "Baidu Kunlun Card is properly installed.",
-                                   r));
+                                   "value[%d %s]",
+                                   r, XPUAPIErrorMsg[r]));
   }
 };
 
@@ -64,24 +59,28 @@ class SoftmaxGradXPUKernel : public framework::OpKernel<T> {
     auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
     const int rank = dx->dims().size();
-    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
+    int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
 
     // allocate memory on device.
     dx->mutable_data<T>(context.GetPlace());
 
-    const int n = SizeToAxis(axis, dx->dims());
-    const int d = SizeFromAxis(axis, dx->dims());
+    std::vector<int> x_dims;
+    for (int i = 0; i < rank; i++) {
+      x_dims.push_back(dx->dims()[i]);
+    }
+    if (axis < 0) {
+      axis += rank;
+    }
 
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    int r =
-        xpu::softmax2d_backward(dev_ctx.x_context(), out->data<float>(),
-                                dout->data<float>(), dx->data<float>(), n, d);
+    int r = xpu::softmax_grad<float>(dev_ctx.x_context(), out->data<float>(),
+                                     dout->data<float>(), dx->data<float>(),
+                                     x_dims, axis);
     PADDLE_ENFORCE_EQ(
         r, XPU_SUCCESS,
         platform::errors::External("XPU API(softmax2d_backward) return wrong "
-                                   "value[%d], please check whether "
-                                   "Baidu Kunlun Card is properly installed.",
-                                   r));
+                                   "value[%d %s]",
+                                   r, XPUAPIErrorMsg[r]));
   }
 };
diff --git a/paddle/fluid/platform/xpu_header.h b/paddle/fluid/platform/xpu_header.h
index 66982769837..bce82b897f0 100644
--- a/paddle/fluid/platform/xpu_header.h
+++ b/paddle/fluid/platform/xpu_header.h
@@ -15,6 +15,7 @@
 #pragma once
 
 #ifdef PADDLE_WITH_XPU
+#include <map>
 #include <string>
 #include <unordered_map>
 
@@ -48,4 +49,11 @@ class XPUActHelper {
     return res->second;
   }
 };
+
+static std::map<int, std::string> XPUAPIErrorMsg = {
+    {xpu::Error_t::SUCCESS, "xpu api success"},
+    {xpu::Error_t::INVALID_PARAM, "xpu api invalid param"},
+    {xpu::Error_t::RUNTIME_ERROR, "xpu api runtime error"},
+    {xpu::Error_t::NO_ENOUGH_WORKSPACE, "xpu api no enough workspace"}};
+
 #endif
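
The XPUAPIErrorMsg table added to xpu_header.h is what the reworked error
messages above index into, pairing the raw status code with a readable string.
A minimal usage sketch (stand-in integer keys; the real table is keyed by
xpu::Error_t values from the vendor header):

    #include <iostream>
    #include <map>
    #include <string>

    static std::map<int, std::string> XPUAPIErrorMsg = {
        {0, "xpu api success"},
        {1, "xpu api invalid param"},
        {2, "xpu api runtime error"},
        {3, "xpu api no enough workspace"}};

    int main() {
      int r = 2;  // pretend an XPU call failed with a runtime error
      std::cout << "XPU kernel return wrong value[" << r << " "
                << XPUAPIErrorMsg[r] << "]" << std::endl;
      return 0;
    }
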
diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py
index 215b4cd039f..fdd236a58f0 100644
--- a/python/paddle/fluid/io.py
+++ b/python/paddle/fluid/io.py
@@ -1915,6 +1915,10 @@ def load(program, model_path, executor=None, var_list=None):
                 place = paddle.fluid.CPUPlace()
             elif p.is_cuda_pinned_place():
                 place = paddle.fluid.CUDAPinnedPlace()
+            elif p.is_xpu_place():
+                p = paddle.fluid.core.Place()
+                p.set_place(t._place())
+                place = paddle.fluid.XPUPlace(p.xpu_device_id())
             else:
                 p = paddle.fluid.core.Place()
                 p.set_place(t._place())
diff --git a/python/paddle/fluid/tests/unittests/op_test_xpu.py b/python/paddle/fluid/tests/unittests/op_test_xpu.py
index 7e19d8e4d8a..37b446174d6 100644
--- a/python/paddle/fluid/tests/unittests/op_test_xpu.py
+++ b/python/paddle/fluid/tests/unittests/op_test_xpu.py
@@ -362,17 +362,6 @@ class XPUOpTest(OpTest):
         if not type(output_names) is list:
             output_names = [output_names]
 
-        numeric_grads = user_defined_grads or [
-            get_numeric_gradient(
-                place,
-                self.scope,
-                self.op,
-                self.inputs,
-                input_to_check,
-                output_names,
-                delta=numeric_grad_delta,
-                in_place=in_place) for input_to_check in inputs_to_check
-        ]
         analytic_grads = self._get_gradient(inputs_to_check, place,
                                             output_names, no_grad_set)
         return analytic_grads
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_softmax_with_cross_entropy_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_softmax_with_cross_entropy_op_xpu.py
index 5a8985315ea..f734d3c25a0 100644
--- a/python/paddle/fluid/tests/unittests/xpu/test_softmax_with_cross_entropy_op_xpu.py
+++ b/python/paddle/fluid/tests/unittests/xpu/test_softmax_with_cross_entropy_op_xpu.py
@@ -13,6 +13,9 @@
 #  limitations under the License.
 
 from __future__ import print_function
+import sys
+sys.path.append("..")
+
 from test_softmax_op import stable_softmax
 from op_test import OpTest
 import paddle.fluid.core as core
@@ -20,8 +23,6 @@ import paddle
 import unittest
 import numpy as np
 
-import sys
-sys.path.append("..")
 
 
 def cross_entropy(softmax, label, soft_label, axis, ignore_index=-1):
--
GitLab