From d19a9b3954f7e29356410824213806b7e27d37e4 Mon Sep 17 00:00:00 2001
From: taixiurong
Date: Mon, 18 Oct 2021 11:24:04 +0800
Subject: [PATCH] [XPU AMP] 1. xpu support gradient acc 2. xpu support create
 tensor in dygraph 3. xpu support update weight params in amp (#36439)

---
 .../fluid/imperative/gradient_accumulator.cc  |  47 ++++-
 .../reduce_ops/reduce_mean_op_xpu.cc          |  99 ++++++++--
 paddle/fluid/operators/slice_op_xpu.cc        | 174 ++++++++----------
 paddle/fluid/platform/xpu/xpu2_op_list.h      |  11 +-
 python/paddle/fluid/framework.py              |  12 ++
 python/paddle/optimizer/adamw.py              |   7 -
 python/paddle/tensor/creation.py              |   4 +-
 7 files changed, 238 insertions(+), 116 deletions(-)

diff --git a/paddle/fluid/imperative/gradient_accumulator.cc b/paddle/fluid/imperative/gradient_accumulator.cc
index fbc5453f821..fd6a070c3fc 100644
--- a/paddle/fluid/imperative/gradient_accumulator.cc
+++ b/paddle/fluid/imperative/gradient_accumulator.cc
@@ -87,9 +87,17 @@ class TensorAddFunctor : public boost::static_visitor<> {
 
 #ifdef PADDLE_WITH_XPU
   void operator()(const platform::XPUPlace& place) {
+    using XPUType = typename XPUTypeTrait<T>::Type;
     platform::XPUDeviceContext* ctx = dynamic_cast<platform::XPUDeviceContext*>(
         platform::DeviceContextPool::Instance().Get(place));
-    xpu::add<T>(ctx->x_context(), x_, y_, y_, static_cast<int>(numel_));
+    int r = xpu::add<XPUType>(
+        ctx->x_context(), reinterpret_cast<const XPUType*>(x_),
+        reinterpret_cast<const XPUType*>(y_), reinterpret_cast<XPUType*>(y_),
+        static_cast<int>(numel_));
+    PADDLE_ENFORCE_EQ(
+        r, XPU_SUCCESS,
+        platform::errors::External("XPU add kernel return wrong value[%d %s]",
+                                   r, XPUAPIErrorMsg[r]));
   }
 #else
   void operator()(const platform::XPUPlace& place) {
@@ -154,6 +162,24 @@ class TensorAddFunctor : public boost::static_visitor<> {
   T* y_;
 };
 
+#ifdef PADDLE_WITH_XPU
+template <typename T>
+void XPUTensorAddFunctor(const platform::Place& place,
+                         const framework::Tensor& src, framework::Tensor* dst) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  platform::XPUDeviceContext* ctx = dynamic_cast<platform::XPUDeviceContext*>(
+      platform::DeviceContextPool::Instance().Get(place));
+  const XPUType* x = reinterpret_cast<const XPUType*>(src.data<T>());
+  XPUType* y = reinterpret_cast<XPUType*>(dst->mutable_data<T>(place));
+  int r = xpu::add<XPUType>(ctx->x_context(), x, y, y,
+                            static_cast<int>(src.numel()));
+  PADDLE_ENFORCE_EQ(
+      r, XPU_SUCCESS,
+      platform::errors::External("XPU add kernel return wrong value[%d %s]", r,
+                                 XPUAPIErrorMsg[r]));
+}
+#endif
+
 template <typename DeviceContext, typename T>
 void TensorAddImpl(const framework::Tensor& src, framework::Tensor* dst,
                    const platform::Place& place) {
@@ -226,7 +252,26 @@ void TensorAdd(const framework::Variable& src, framework::Variable* dst) {
     return;
   }
 #endif
+
+#ifdef PADDLE_WITH_XPU
+  if (platform::is_xpu_place(place)) {
+    if (data_type == framework::DataTypeTrait<float>::DataType()) {
+      XPUTensorAddFunctor<float>(place, src_tensor, dst_tensor);
+    } else if (data_type ==
+               framework::DataTypeTrait<platform::float16>::DataType()) {
+      XPUTensorAddFunctor<platform::float16>(place, src_tensor, dst_tensor);
+    } else {
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Gradient accumulation of data type (%s) on place (%s) is not "
+          "supported in imperative mode",
+          framework::DataTypeToString(data_type), place));
+    }
+    return;
+  }
+#endif
+
   PADDLE_TENSOR_ADD(float);
+
 #ifndef PADDLE_WITH_XPU
   // NOTE(phlrain): xpu only support float
   PADDLE_TENSOR_ADD(double);
diff --git a/paddle/fluid/operators/reduce_ops/reduce_mean_op_xpu.cc b/paddle/fluid/operators/reduce_ops/reduce_mean_op_xpu.cc
index b82ecbbe2fc..d6c1dc5f02d 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_mean_op_xpu.cc
+++ b/paddle/fluid/operators/reduce_ops/reduce_mean_op_xpu.cc
@@ -23,30 +23,103 @@ namespace paddle {
 namespace operators {
 
 template <typename DeviceContext, typename T>
 class ReduceMeanXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     PADDLE_ENFORCE_EQ(
         platform::is_xpu_place(context.GetPlace()), true,
         platform::errors::Unavailable("This kernel only runs on XPU."));
-    // bool reduce_all = context.Attr<bool>("reduce_all");
+    bool reduce_all = context.Attr<bool>("reduce_all");
     auto* input = context.Input<Tensor>("X");
     auto* output = context.Output<Tensor>("Out");
     output->mutable_data<T>(context.GetPlace());
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    int ndim = input->dims().size();
-    std::vector<int> idims;
+
+    std::vector<int> xdims;
     for (int i = 0; i < input->dims().size(); i++) {
-      idims.push_back(input->dims()[i]);
+      xdims.push_back(input->dims()[i]);
     }
-    auto dims = context.Attr<std::vector<int>>("dim");
-    int rdim = dims.size();
-    int r =
-        xpu::reduce(dev_ctx.x_context(), input->data<T>(), output->data<T>(),
-                    idims.data(), ndim, dims.data(), rdim, xpu::REDUCE_MEAN);
-    PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
-                      platform::errors::External("XPU kernel error!"));
+    auto rdims = context.Attr<std::vector<int>>("dim");
+    if (reduce_all) {
+      rdims.clear();
+      for (size_t i = 0; i < xdims.size(); i++) {
+        rdims.push_back(static_cast<int>(i));
+      }
+    }
+    int r = xpu::reduce_mean(
+        dev_ctx.x_context(), reinterpret_cast<const XPUType*>(input->data<T>()),
+        reinterpret_cast<XPUType*>(output->data<T>()), xdims, rdims);
+
+    PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
+                      platform::errors::External(
+                          "XPU reduce_mean kernel return wrong value[%d %s]", r,
+                          XPUAPIErrorMsg[r]));
   }
 };
+
+template <typename DeviceContext, typename T>
+class ReduceMeanGradXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<Tensor>("X");
+    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+
+    XPUType* x_data =
+        reinterpret_cast<XPUType*>(input_grad->mutable_data<T>(ctx.GetPlace()));
+    const XPUType* dy_data =
+        reinterpret_cast<const XPUType*>(output_grad->data<T>());
+
+    bool reduce_all = ctx.Attr<bool>("reduce_all");
+    auto reduce_dims = ctx.Attr<std::vector<int>>("dim");
+
+    std::vector<int> xdims;
+    for (int i = 0; i < input->dims().size(); i++) {
+      xdims.push_back(input->dims()[i]);
+    }
+    std::vector<int> ydims;
+    for (int i = 0; i < output_grad->dims().size(); i++) {
+      ydims.push_back(output_grad->dims()[i]);
+    }
+
+    int reduce_numel = 1;
+    if (reduce_all) {
+      reduce_dims.clear();
+      for (size_t d = 0; d < xdims.size(); ++d) {
+        reduce_dims.push_back(static_cast<int>(d));
+      }
+    }
+    for (auto& d : reduce_dims) {
+      if (d < 0) {
+        d = d + xdims.size();
+      }
+      reduce_numel *= xdims[d];
+    }
+
+    float val = 1.0f / static_cast<float>(reduce_numel);
+
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+
+    int r = xpu::constant(dev_ctx.x_context(), x_data, input->numel(),
+                          static_cast<XPUType>(val));
+
+    PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
+                      platform::errors::External(
+                          "XPU constant kernel return wrong value[%d %s]", r,
+                          XPUAPIErrorMsg[r]));
+    r = xpu::broadcast_mul(dev_ctx.x_context(), x_data, dy_data, x_data, xdims,
+                           ydims);
+
+    PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
+                      platform::errors::External(
+                          "XPU broadcast_mul kernel return wrong value[%d %s]",
+                          r, XPUAPIErrorMsg[r]));
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
@@ -54,4 +127,8 @@
 REGISTER_OP_XPU_KERNEL(
     reduce_mean,
     ops::ReduceMeanXPUKernel<paddle::platform::XPUDeviceContext, float>);
+REGISTER_OP_XPU_KERNEL(
+    reduce_mean_grad,
+    ops::ReduceMeanGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
+
 #endif
diff --git a/paddle/fluid/operators/slice_op_xpu.cc b/paddle/fluid/operators/slice_op_xpu.cc
index 5f98efe8e91..6ac1027b0ce 100644
--- a/paddle/fluid/operators/slice_op_xpu.cc
+++ b/paddle/fluid/operators/slice_op_xpu.cc
@@ -27,6 +27,8 @@ using Tensor = framework::Tensor;
 
 template <typename DeviceContext, typename T>
 class SliceXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto in = ctx.Input<framework::Tensor>("Input");
@@ -83,114 +85,93 @@ class SliceXPUKernel : public framework::OpKernel<T> {
     }
 
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    auto* in_data = in->data<T>();
-    auto* out_data = out->mutable_data<T>(ctx.GetPlace());
-    int r = xpu::slice<T>(dev_ctx.x_context(), in_data, out_data, shape,
-                          starts_extension, ends_extension);
-    PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
-                      platform::errors::External("XPU slice kernel error!"));
+    const XPUType* in_data = reinterpret_cast<const XPUType*>(in->data<T>());
+    XPUType* out_data =
+        reinterpret_cast<XPUType*>(out->mutable_data<T>(ctx.GetPlace()));
+    int r = xpu::slice<XPUType>(dev_ctx.x_context(), in_data, out_data, shape,
+                                starts_extension, ends_extension);
+    PADDLE_ENFORCE_EQ(
+        r, XPU_SUCCESS,
+        platform::errors::External("XPU slice kernel return wrong value[%d %s]",
+                                   r, XPUAPIErrorMsg[r]));
   }
 };
 
 template <typename DeviceContext, typename T>
 class SliceGradXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto* d_in = ctx.Output<framework::Tensor>(framework::GradVarName("Input"));
-    d_in->mutable_data<T>(ctx.GetPlace());
-
-    auto in_dims = d_in->dims();
-    auto axes = ctx.Attr<std::vector<int>>("axes");
-    auto starts = ctx.Attr<std::vector<int>>("starts");
-    auto ends = ctx.Attr<std::vector<int>>("ends");
+    auto* input = ctx.Input<Tensor>("Input");
+    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dinput = ctx.Output<Tensor>(framework::GradVarName("Input"));
+
+    auto axes_int = ctx.Attr<std::vector<int>>("axes");
+    auto starts_int = ctx.Attr<std::vector<int>>("starts");
+    auto ends_int = ctx.Attr<std::vector<int>>("ends");
+    std::vector<int> axes(axes_int.begin(), axes_int.end());
+    std::vector<int> starts(starts_int.begin(), starts_int.end());
+    std::vector<int> ends(ends_int.begin(), ends_int.end());
+
+    // Get the accurate attribute value of starts and ends
+    auto starts_tensor_list = ctx.MultiInput<Tensor>("StartsTensorList");
+    if (ctx.HasInput("StartsTensor")) {
+      starts = GetDataFromTensor<int>(ctx.Input<Tensor>("StartsTensor"));
+    } else if (starts_tensor_list.size() > 0) {
+      starts = GetDataFromTensorList<int>(starts_tensor_list);
+    }
 
-    // prepare starts, ends on XPU
-    int dim_value = 0, start = 0, end = 0;
-    // If a negative value is passed for any of the start or end indices,
-    // it represents number of elements before the end of that dimension.
-    // If the value passed to start or end is larger than the n
-    // (the number of elements in this dimension), it represents n.
-    for (size_t i = 0; i < axes.size(); ++i) {
-      dim_value = in_dims[axes[i]];
-      start = starts[i];
-      end = ends[i];
-      start = start < 0 ? (start + dim_value) : start;
-      end = end < 0 ? (end + dim_value) : end;
-      start = std::max(start, 0);
-      end = std::max(end, 0);
-      end = std::min(end, dim_value);
-      PADDLE_ENFORCE_GT(end, start, platform::errors::InvalidArgument(
-                                        "end should greater than start"));
-      starts[i] = start;
-      ends[i] = end;
+    auto ends_tensor_list = ctx.MultiInput<Tensor>("EndsTensorList");
+    if (ctx.HasInput("EndsTensor")) {
+      ends = GetDataFromTensor<int>(ctx.Input<Tensor>("EndsTensor"));
+    } else if (ends_tensor_list.size() > 0) {
+      ends = GetDataFromTensorList<int>(ends_tensor_list);
     }
-    size_t shape_size = in_dims.size();
-    // the slice XPU kernel require that the length of `start`, `end` must be
-    // equal
-    // to the dims size of input tensor, therefore, if shape_size > axes.size(),
-    // the `starts_extension` and `ends_extension` is necessary.
-    std::vector<int> starts_extension(shape_size, 0);
-    std::vector<int> ends_extension(shape_size, 0);
-    if (shape_size > axes.size()) {
-      for (size_t i = 0; i < shape_size; ++i) {
-        ends_extension[i] = in_dims[i];
-      }
-      for (size_t i = 0; i < axes.size(); ++i) {
-        starts_extension[axes[i]] = starts[i];
-        ends_extension[axes[i]] = ends[i];
+
+    const auto& in_dims = input->dims();
+    int rank = in_dims.size();
+
+    std::vector<int> pad_left(rank);
+    std::vector<int> out_dims(rank);
+    std::vector<int> pad_right(rank);
+    int cnt = 0;
+    for (int i = 0; i < in_dims.size(); ++i) {
+      int start = 0;
+      int end = in_dims[i];
+      int axis = cnt < static_cast<int>(axes.size()) ? axes[cnt] : -1;
+      if (axis == i) {
+        start = starts[cnt];
+        if (start < 0) {
+          start = (start + in_dims[i]);
+        }
+        start = std::max(start, static_cast<int>(0));
+        end = ends[cnt];
+        if (end < 0) {
+          end = (end + in_dims[i]);
+        }
+        end = std::min(end, static_cast<int>(in_dims[i]));
+        cnt++;
       }
-    }
-    int* starts_device = nullptr;
-    int* ends_device = nullptr;
-    int* starts_host =
-        shape_size > axes.size() ? starts_extension.data() : starts.data();
-    int* ends_host =
-        shape_size > axes.size() ? ends_extension.data() : ends.data();
-    PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast<void**>(&starts_device),
-                                 shape_size * sizeof(int)),
-                      XPU_SUCCESS,
-                      platform::errors::External("XPU has no enough memory"));
-    PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast<void**>(&ends_device),
-                                 shape_size * sizeof(int)),
-                      XPU_SUCCESS,
-                      platform::errors::External("XPU has no enough memory"));
-    memory::Copy(BOOST_GET_CONST(platform::XPUPlace, ctx.GetPlace()),
-                 starts_device, platform::CPUPlace(), starts_host,
-                 shape_size * sizeof(int));
-    memory::Copy(BOOST_GET_CONST(platform::XPUPlace, ctx.GetPlace()),
-                 ends_device, platform::CPUPlace(), ends_host,
-                 shape_size * sizeof(int));
-    // prepare shape on XPU
-    std::vector<int> shape(shape_size, 0);
-    for (size_t i = 0; i < shape_size; ++i) {
-      shape[i] = in_dims[i];
+      pad_left[i] = start;
+      out_dims[i] = end - start;
+      pad_right[i] = in_dims[i] - out_dims[i] - pad_left[i];
     }
-    int* shape_device = nullptr;
-    PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast<void**>(&shape_device),
-                                 shape_size * sizeof(int)),
-                      XPU_SUCCESS,
-                      platform::errors::External("XPU has no enough memory"));
-    memory::Copy(BOOST_GET_CONST(platform::XPUPlace, ctx.GetPlace()),
-                 shape_device, platform::CPUPlace(), shape.data(),
-                 shape_size * sizeof(int));
 
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    int r =
-        xpu::slice_backward(dev_ctx.x_context(), shape_device, starts_device,
-                            ends_device, shape_size, d_out->data<T>(),
-                            d_in->data<T>(), d_in->numel(), d_out->numel());
-    PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
-                      platform::errors::External("xpu slice kernel error"));
-    dev_ctx.Wait();
-    // free device data
-    xpu_free(shape_device);
-    xpu_free(starts_device);
-    xpu_free(ends_device);
+    const XPUType* dout_data =
+        reinterpret_cast<const XPUType*>(dout->data<T>());
+    XPUType* din_data =
+        reinterpret_cast<XPUType*>(dinput->mutable_data<T>(ctx.GetPlace()));
+    int r = xpu::pad<XPUType>(dev_ctx.x_context(), dout_data, din_data,
+                              out_dims, pad_left, pad_right, XPUType(0));
+    PADDLE_ENFORCE_EQ(
+        r, XPU_SUCCESS,
+        platform::errors::External("XPU pad kernel return wrong value[%d %s]",
+                                   r, XPUAPIErrorMsg[r]));
   }
 };
-
 }  // namespace operators
 }  // namespace paddle
@@ -198,8 +179,13 @@ namespace ops = paddle::operators;
 REGISTER_OP_XPU_KERNEL(
     slice, ops::SliceXPUKernel<paddle::platform::XPUDeviceContext, float>,
-    ops::SliceXPUKernel<paddle::platform::XPUDeviceContext, int>);
+    ops::SliceXPUKernel<paddle::platform::XPUDeviceContext, int>,
+    ops::SliceXPUKernel<paddle::platform::XPUDeviceContext,
+                        paddle::platform::float16>);
 
 REGISTER_OP_XPU_KERNEL(
     slice_grad,
-    ops::SliceGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    ops::SliceGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::SliceGradXPUKernel<paddle::platform::XPUDeviceContext, int>,
+    ops::SliceGradXPUKernel<paddle::platform::XPUDeviceContext,
+                            paddle::platform::float16>);
 #endif
diff --git a/paddle/fluid/platform/xpu/xpu2_op_list.h b/paddle/fluid/platform/xpu/xpu2_op_list.h
index 651243a4dfe..5d45e5d9d50 100644
--- a/paddle/fluid/platform/xpu/xpu2_op_list.h
+++ b/paddle/fluid/platform/xpu/xpu2_op_list.h
@@ -109,7 +109,16 @@ XPUOpMap& get_kl2_ops() {
                     pOpKernelType(vartype::FP16, XPUPlace())})},
     {"iou_similarity",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
-    {"arg_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}
+    {"arg_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"reduce_mean", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"reduce_mean_grad",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"slice", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                            pOpKernelType(vartype::FP16, XPUPlace()),
+                            pOpKernelType(vartype::INT32, XPUPlace())})},
+    {"slice_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                                 pOpKernelType(vartype::FP16, XPUPlace()),
+                                 pOpKernelType(vartype::INT32, XPUPlace())})},
 
     // AddMore
 };
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index c6367911b88..156ba07a4ce 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -313,6 +313,18 @@ def _current_expected_place():
                     "You are using GPU version Paddle, but your CUDA device is not set properly. CPU device will be used by default."
                 )
                 _global_expected_place_ = core.CPUPlace()
+        elif core.is_compiled_with_xpu():
+            try:
+                device_count = core.get_xpu_device_count()
+            except Exception as e:
+                device_count = 0
+            if device_count > 0:
+                _global_expected_place_ = core.XPUPlace(0)
+            else:
+                warnings.warn(
+                    "You are using XPU version Paddle, but your XPU device is not set properly. CPU device will be used by default."
+                )
+                _global_expected_place_ = core.CPUPlace()
         else:
             _global_expected_place_ = core.CPUPlace()
 
diff --git a/python/paddle/optimizer/adamw.py b/python/paddle/optimizer/adamw.py
index f26ee80d0af..55aaac8dc48 100644
--- a/python/paddle/optimizer/adamw.py
+++ b/python/paddle/optimizer/adamw.py
@@ -190,9 +190,6 @@ class AdamW(Adam):
 
         self.type = "adamw"
 
-        if core.is_compiled_with_xpu():
-            self.type = "adam"
-
         # Use _auxiliary_vars together with _set_auxiliary_var/_get_auxiliary_var to achieve that.
         self._auxiliary_vars = dict()
 
@@ -259,10 +256,6 @@ class AdamW(Adam):
             paddle.fluid.layers.assign(input=scaled_param, output=param)
 
     def _append_optimize_op(self, block, param_and_grad):
-        if paddle.is_compiled_with_xpu():
-            self._append_decoupled_weight_decay(block, param_and_grad)
-            return super(AdamW, self)._append_optimize_op(block, param_and_grad)
-
         assert isinstance(block, framework.Block)
         if isinstance(param_and_grad, dict):
             param_and_grad = self._update_param_group(param_and_grad)
diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py
index 71968d67ed6..72b6bd29fd9 100644
--- a/python/paddle/tensor/creation.py
+++ b/python/paddle/tensor/creation.py
@@ -104,9 +104,9 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
     if place is None:
         place = _current_expected_place()
     elif not isinstance(place, (core.Place, core.CPUPlace, core.CUDAPinnedPlace,
-                                core.CUDAPlace, core.NPUPlace)):
+                                core.CUDAPlace, core.NPUPlace, core.XPUPlace)):
         raise ValueError(
-            "'place' must be any of paddle.Place, paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace, paddle.NPUPlace"
+            "'place' must be any of paddle.Place, paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace, paddle.NPUPlace, paddle.XPUPlace"
         )
 
     #Todo(zhouwei): Support allocate tensor on any other specified card
--
GitLab
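
The patch itself carries no usage example. The sketch below is not part of the commit; it is a minimal dygraph illustration of the three behaviors the subject line lists (gradient accumulation on XPU, tensor creation on an XPU place, and native AdamW weight updates under AMP), assuming an XPU build of Paddle 2.x with at least one visible XPU device. All API names used (`paddle.set_device`, `paddle.XPUPlace`, `paddle.to_tensor`, `paddle.amp.auto_cast`, `paddle.optimizer.AdamW`) are from the public Paddle API; the tensor values and layer sizes are arbitrary.

```python
# Hedged usage sketch (not part of the patch): assumes an XPU build of Paddle
# and one visible XPU device.
import paddle

paddle.set_device("xpu")  # dygraph tensors now default to XPUPlace(0)

# to_tensor now accepts paddle.XPUPlace as a valid `place` argument
x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]], place=paddle.XPUPlace(0))

linear = paddle.nn.Linear(2, 2)
# AdamW no longer falls back to plain Adam on XPU builds
opt = paddle.optimizer.AdamW(parameters=linear.parameters(), weight_decay=0.01)

with paddle.amp.auto_cast():  # fp16 forward; grads are added by the XPU add kernel
    loss_a = linear(x).mean()
    loss_b = linear(x).mean()
loss_a.backward()
loss_b.backward()  # second backward accumulates into the same .grad buffers
opt.step()
opt.clear_grad()
```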