diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index b21cb6b55a338719fa3aac0bf6549a62f897463f..27a4e054a7bb18f4b5c999b77993cfcaa6c4585e 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -120,6 +120,7 @@ XPUOpMap& get_kl2_ops() {
                     phi::DataType::FLOAT16,
                     phi::DataType::FLOAT64,
                     phi::DataType::BOOL,
+                    phi::DataType::INT8,
                     phi::DataType::UINT8,
                     phi::DataType::INT64,
                     phi::DataType::INT32})},
@@ -286,6 +287,7 @@ XPUOpMap& get_kl2_ops() {
      XPUKernelSet({phi::DataType::INT64,
                    phi::DataType::INT32,
                    phi::DataType::INT16,
+                   phi::DataType::INT8,
                    phi::DataType::UINT8,
                    phi::DataType::BOOL,
                    phi::DataType::FLOAT32,
diff --git a/paddle/phi/core/visit_type.h b/paddle/phi/core/visit_type.h
index cc12be86e8dd21d544509662b04c2069d24f0173..8343343a3611ce0b5969e98743aa7be0f50a714c 100644
--- a/paddle/phi/core/visit_type.h
+++ b/paddle/phi/core/visit_type.h
@@ -335,4 +335,20 @@ namespace phi {
     }                                                                        \
   }()
 
+#define PD_VISIT_XPU_REDUCE_TYPES(TYPE, NAME, ...)                           \
+  [&] {                                                                      \
+    const auto& __dtype__ = TYPE;                                            \
+    switch (__dtype__) {                                                     \
+      PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::INT8, int8_t, __VA_ARGS__) \
+      PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::INT32, int32_t, __VA_ARGS__) \
+      PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::INT64, int64_t, __VA_ARGS__) \
+      PD_PRIVATE_CASE_TYPE(                                                  \
+          NAME, ::phi::DataType::FLOAT16, phi::float16, __VA_ARGS__)         \
+      PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::FLOAT32, float, __VA_ARGS__) \
+      default:                                                               \
+        PADDLE_THROW(phi::errors::InvalidArgument(                           \
+            "Invalid enum data type `%d`.", static_cast<int>(__dtype__)));   \
+    }                                                                        \
+  }()
+
 }  // namespace phi
diff --git a/paddle/phi/kernels/cpu/full_kernel.cc b/paddle/phi/kernels/cpu/full_kernel.cc
index e7dd6249f3644c53db9a6b394020a6b9c3438d6f..d9ab771664a8f4fb8156346a1e1985977e1728f0 100644
--- a/paddle/phi/kernels/cpu/full_kernel.cc
+++ b/paddle/phi/kernels/cpu/full_kernel.cc
@@ -88,6 +88,7 @@ PD_REGISTER_KERNEL(full,
                    phi::FullKernel,
                    float,
                    double,
+                   int8_t,
                    uint8_t,
                    int16_t,
                    int,
diff --git a/paddle/phi/kernels/elementwise_kernel.cc b/paddle/phi/kernels/elementwise_kernel.cc
index b96b50d857ff6fe76b2b4608fd859808fbeea5ea..98d76c2d944f3d4219c3493b3ed4c06405d58359 100644
--- a/paddle/phi/kernels/elementwise_kernel.cc
+++ b/paddle/phi/kernels/elementwise_kernel.cc
@@ -304,9 +304,14 @@ PD_REGISTER_KERNEL(divide,
 PD_REGISTER_KERNEL(
     divide, XPU, ALL_LAYOUT, phi::DivideKernel, phi::dtype::float16, float) {}
 
-PD_REGISTER_KERNEL(
-    add, XPU, ALL_LAYOUT, phi::AddKernel, phi::dtype::float16, float, int64_t) {
-}
+PD_REGISTER_KERNEL(add,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::AddKernel,
+                   phi::dtype::float16,
+                   float,
+                   int,
+                   int64_t) {}
 
 PD_REGISTER_KERNEL(multiply,
                    XPU,
diff --git a/paddle/phi/kernels/xpu/cast_kernel.cc b/paddle/phi/kernels/xpu/cast_kernel.cc
index 8757e7344356905151ef139a904ae0ce823cdfd5..74e2a622dba865fc45b9af38b870e467f2d84cae 100644
--- a/paddle/phi/kernels/xpu/cast_kernel.cc
+++ b/paddle/phi/kernels/xpu/cast_kernel.cc
@@ -72,6 +72,13 @@ void CastKernel(const Context& dev_ctx,
           dev_ctx.template Alloc<int64_t>(out),
           numel);
       break;
+    case phi::DataType::INT8:
+      r = xpu::cast<XPUInTDType, int8_t>(
+          dev_ctx.x_context(),
+          reinterpret_cast<const XPUInTDType*>(in_data),
+          dev_ctx.template Alloc<int8_t>(out),
+          numel);
+      break;
     case phi::DataType::UINT8:
       r = xpu::cast<XPUInTDType, uint8_t>(
           dev_ctx.x_context(),
@@ -104,6 +111,7 @@ PD_REGISTER_KERNEL(cast,
                    phi::dtype::float16,
                    int64_t,
                    bool,
+                   int8_t,
                    uint8_t,
                    double) {
   kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
 }
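For readers unfamiliar with phi's visitor macros: `PD_VISIT_XPU_REDUCE_TYPES` above follows the same shape as the existing `PD_VISIT_*` macros in `visit_type.h` — a runtime `DataType` switch whose cases each alias a concrete C++ type (`data_t`, via `PD_PRIVATE_CASE_TYPE`) and then invoke the caller's lambda, so one call site is instantiated once per supported dtype. A minimal self-contained sketch of the pattern (toy macro and enum names, not phi's actual definitions):

```cpp
#include <cstdint>
#include <cstdio>
#include <stdexcept>

enum class DataType { INT8, INT32, INT64, FLOAT32 };

// One case: bind the concrete type to `data_t`, then run the caller's lambda.
#define VISIT_CASE(enum_value, ctype, ...) \
  case DataType::enum_value: {             \
    using data_t = ctype;                  \
    return __VA_ARGS__();                  \
  }

#define VISIT_REDUCE_TYPES(dtype, ...)            \
  [&] {                                           \
    switch (dtype) {                              \
      VISIT_CASE(INT8, int8_t, __VA_ARGS__)       \
      VISIT_CASE(INT32, int32_t, __VA_ARGS__)     \
      VISIT_CASE(INT64, int64_t, __VA_ARGS__)     \
      VISIT_CASE(FLOAT32, float, __VA_ARGS__)     \
      default:                                    \
        throw std::invalid_argument("bad dtype"); \
    }                                             \
  }()

int main() {
  DataType dt = DataType::INT8;  // runtime value
  VISIT_REDUCE_TYPES(dt, ([&] {
    // `data_t` is the concrete C++ type selected for `dt`.
    std::printf("element size: %zu byte(s)\n", sizeof(data_t));
  }));
}
```

The extra parentheses around the lambda at the call site keep any top-level commas inside it from splitting the macro arguments — the same reason the phi call sites below write `([&] { ... })`.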
diff --git a/paddle/phi/kernels/xpu/full_kernel.cc b/paddle/phi/kernels/xpu/full_kernel.cc
index d0b6cfda981df024aa2b1e60fa38bfa6ddb21a2a..f1754b0631ad4f87ef729857bfe106d6ee6bfdfe 100644
--- a/paddle/phi/kernels/xpu/full_kernel.cc
+++ b/paddle/phi/kernels/xpu/full_kernel.cc
@@ -119,6 +119,7 @@ PD_REGISTER_KERNEL(full,
                    ALL_LAYOUT,
                    phi::FullKernel,
                    float,
+                   int8_t,
                    uint8_t,
                    int16_t,
                    int,
diff --git a/paddle/phi/kernels/xpu/instance_norm_grad_kernel.cc b/paddle/phi/kernels/xpu/instance_norm_grad_kernel.cc
index 641794dab0a4387cc8d2a2bf4787b751a338db46..dba0e2ccfd76514a422b7f49e61e9089dee96991 100644
--- a/paddle/phi/kernels/xpu/instance_norm_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/instance_norm_grad_kernel.cc
@@ -77,27 +77,45 @@ void InstanceNormGradKernel(const Context& dev_ctx,
                         scale_ptr->dims()));
   }
 
-  DenseTensor scale_tmp;
+  xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+  float* scale_ptr_data_tmp;
   int r;
   if (!scale_ptr) {
-    scale_tmp.Resize({C});
-    dev_ctx.template Alloc<float>(&scale_tmp);
+    scale_ptr_data_tmp = RAII_GUARD.alloc_l3_or_gm<float>(C);
     r = xpu::constant(dev_ctx.x_context(),
-                      reinterpret_cast<float*>(scale_tmp.data<float>()),
-                      scale_tmp.numel(),
-                      static_cast<float>(1));
+                      reinterpret_cast<float*>(scale_ptr_data_tmp),
+                      C,
+                      static_cast<float>(1));
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
   }
-  auto scale_ptr_tmp = scale_ptr ? scale_ptr : &scale_tmp;
+  auto scale_ptr_data =
+      scale_ptr ? scale_ptr->data<float>() : scale_ptr_data_tmp;
 
-  xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+  if ((H * W * D) == 1) {
+    r = xpu::copy(dev_ctx.x_context(),
+                  reinterpret_cast<const XPUType*>(d_y.data<T>()),
+                  reinterpret_cast<XPUType*>(d_x->data<T>()),
+                  d_y.numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
+    r = xpu::constant(dev_ctx.x_context(),
+                      reinterpret_cast<float*>(d_scale->data<float>()),
+                      C,
+                      static_cast<float>(0));
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
+    r = xpu::constant(dev_ctx.x_context(),
+                      reinterpret_cast<float*>(d_bias->data<float>()),
+                      C,
+                      static_cast<float>(0));
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
+    return;
+  }
 
   auto d_x_data =
       d_x ? d_x->data<T>() : RAII_GUARD.alloc_l3_or_gm<T>(x.numel());
   r = xpu::instance_norm_grad(dev_ctx.x_context(),
                               reinterpret_cast<const XPUType*>(x.data<T>()),
                               reinterpret_cast<const XPUType*>(d_y.data<T>()),
                               reinterpret_cast<XPUType*>(d_x_data),
-                              scale_ptr_tmp->data<float>(),
+                              scale_ptr_data,
                               saved_mean.data<float>(),
                               saved_variance.data<float>(),
                               d_scale_data,
diff --git a/paddle/phi/kernels/xpu/reduce.h b/paddle/phi/kernels/xpu/reduce.h
index 02369e268676be233cfb802a36c647f465084ce7..a9ba6c1ac1347ac459ed9dbfe8fd53e3843d0b82 100644
--- a/paddle/phi/kernels/xpu/reduce.h
+++ b/paddle/phi/kernels/xpu/reduce.h
@@ -19,6 +19,10 @@
 #include <set>
 #include <vector>
 
+#include "paddle/phi/core/visit_type.h"
+#include "paddle/phi/kernels/cast_kernel.h"
+#include "paddle/phi/kernels/xpu/reduce_util.h"
+
 namespace phi {
 
 template <typename Context, typename T>
@@ -82,4 +86,89 @@ int XPUReduce(const Context& dev_ctx,
   return r;
 }
 
+template <typename DeviceContext, typename T, typename Functor>
+void ReduceKernelImpl(const DeviceContext& dev_ctx,
+                      const phi::DenseTensor& input,
+                      phi::DenseTensor* output,
+                      const std::vector<int>& xdims,
+                      const std::vector<int>& reduce_dims) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  dev_ctx.template Alloc<T>(output);
+  const auto* x_data = input.data<T>();
+  auto* y_data = output->data<T>();
+  if (reduce_dims.size() == 0) {
+    int r = xpu::copy(dev_ctx.x_context(),
+                      reinterpret_cast<const XPUType*>(x_data),
+                      reinterpret_cast<XPUType*>(y_data),
+                      input.numel() * sizeof(T));
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
+  } else {
+    Functor func;
+    func(dev_ctx.x_context(), x_data, y_data, xdims, reduce_dims);
+  }
+}
+
+template <typename DeviceContext, typename T, typename Functor>
+void XPUReduce(const DeviceContext& dev_ctx,
+               const DenseTensor& x,
+               const std::vector<int64_t>& dims,
+               bool keep_dim,
+               bool reduce_all,
+               DataType out_dtype,
+               DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
+
+  const auto& input_dim_size = x.dims().size();
+  std::vector<int> true_dims;
+  for (size_t i = 0; i < dims.size(); ++i) {
+    if (dims[i] < 0) {
+      true_dims.push_back(dims[i] + input_dim_size);
+    } else {
+      true_dims.push_back(dims[i]);
+    }
+  }
+  std::vector<int> reduce_dims;
+  std::vector<int> xdims((input_dim_size));
+  for (int i = 0; i < input_dim_size; ++i) {
+    xdims[i] = x.dims()[i];
+  }
+  if (reduce_all) {
+    for (int i = 0; i < input_dim_size; ++i) {
+      reduce_dims.push_back(i);
+    }
+  } else {
+    std::set<int> dims_set(true_dims.begin(), true_dims.end());
+    for (auto i = 0; i < input_dim_size; i++) {
+      if (dims_set.find(i) != dims_set.end()) {
+        if (x.dims()[i] != 1) {
+          reduce_dims.push_back(i);
+        }
+      }
+    }
+  }
+  // no need to cast dtype
+  if (out_dtype == phi::DataType::UNDEFINED || out_dtype == x.dtype()) {
+    // do reduce sum
+    PD_VISIT_XPU_REDUCE_TYPES(
+        x.dtype(), "ReduceKernelImpl", ([&] {
+          phi::ReduceKernelImpl<DeviceContext, data_t, Functor>(
+              dev_ctx, x, out, xdims, reduce_dims);
+        }));
+  } else {
+    // cast x tensor to out_dtype
+    auto tmp_tensor = phi::Cast<T, DeviceContext>(dev_ctx, x, out_dtype);
+
+    // do reduce sum
+    PD_VISIT_XPU_REDUCE_TYPES(
+        out_dtype, "ReduceKernelImpl", ([&] {
+          phi::ReduceKernelImpl<DeviceContext, data_t, Functor>(
+              dev_ctx, tmp_tensor, out, xdims, reduce_dims);
+        }));
+
+    if (dev_ctx.x_context()->xpu_stream) {
+      dev_ctx.Wait();
+    }
+  }
+}
+
 }  // namespace phi
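The refactor above replaces `reduce_sum`'s per-kernel lambda with a reduction op passed as a template parameter: `ReduceKernelImpl` only writes `Functor func; func(...)`, so adding another reduction later means adding a functor to `reduce_util.h` rather than another `XPUReduce` overload. A self-contained sketch of that shape (host-side toy code; `std::accumulate` and `std::max_element` stand in for the XPU runtime calls):

```cpp
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

// Stand-in for reduce_util.h's SumFunctor: the reduction op is a stateless
// functor, so the driver template never names the op directly.
struct SumFunctor {
  template <typename X, typename Y>
  void operator()(const X* x, Y* y, int n) const {
    *y = std::accumulate(x, x + n, Y{0});  // plays the role of xpu::reduce_sum
  }
};

// A second op drops in with no change to the driver template below.
struct MaxFunctor {
  template <typename X, typename Y>
  void operator()(const X* x, Y* y, int n) const {
    *y = *std::max_element(x, x + n);  // plays the role of xpu::reduce_max
  }
};

// Mirrors ReduceKernelImpl's `Functor func; func(...)` dispatch.
template <typename T, typename Functor>
void ReduceImpl(const std::vector<T>& in, T* out) {
  Functor func;
  func(in.data(), out, static_cast<int>(in.size()));
}

int main() {
  std::vector<int> v{1, 2, 3, 4};
  int sum = 0, max = 0;
  ReduceImpl<int, SumFunctor>(v, &sum);
  ReduceImpl<int, MaxFunctor>(v, &max);
  std::printf("sum=%d max=%d\n", sum, max);  // sum=10 max=4
}
```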
"paddle/phi/kernels/xpu/reduce.h" @@ -29,23 +28,11 @@ void SumRawKernel(const Context& dev_ctx, bool reduce_all, DataType out_dtype, DenseTensor* out) { - reduce_all = recompute_reduce_all(x, dims, reduce_all); - using XPUType = typename XPUTypeTrait::Type; - - auto f = [](xpu::Context* ctx, - const T* x, - T* y, - const std::vector& xdims, - const std::vector& reduce_dims) { - return xpu::reduce_sum(ctx, - reinterpret_cast(x), - reinterpret_cast(y), - xdims, - reduce_dims); - }; - int r = XPUReduce( - dev_ctx, x, dims.GetData(), keep_dim, reduce_all, out, f); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum"); + if (out_dtype == DataType::UNDEFINED && out->dtype() != x.dtype()) { + out_dtype = out->dtype(); + } + XPUReduce( + dev_ctx, x, dims.GetData(), keep_dim, reduce_all, out_dtype, out); } } // namespace phi diff --git a/paddle/phi/kernels/xpu/reduce_util.h b/paddle/phi/kernels/xpu/reduce_util.h new file mode 100644 index 0000000000000000000000000000000000000000..cd624cc1ef1f0da56c1386d0a870ee08fe138bb7 --- /dev/null +++ b/paddle/phi/kernels/xpu/reduce_util.h @@ -0,0 +1,39 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/backends/xpu/xpu_context.h" + +namespace phi { + +//////// Sum Functor /////// +struct SumFunctor { + template + void operator()(const DeviceContext& ctx, + const X* x, + Y* y, + const std::vector& xdims, + const std::vector& reduce_dims) { + using XPUType = typename XPUTypeTrait::Type; + int r = xpu::reduce_sum(ctx, + reinterpret_cast(x), + reinterpret_cast(y), + xdims, + reduce_dims); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum"); + } +}; +} // namespace phi diff --git a/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py index 99f29e5f86650b35b9ef79dd5b08f43f78d9515d..e013432d13b97695c42817badbf20024790cb8a9 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py @@ -35,6 +35,7 @@ typeid_dict = { 'float32': int(core.VarDesc.VarType.FP32), 'float16': int(core.VarDesc.VarType.FP16), 'bool': int(core.VarDesc.VarType.BOOL), + 'int8': int(core.VarDesc.VarType.INT8), 'uint8': int(core.VarDesc.VarType.UINT8), 'float64': int(core.VarDesc.VarType.FP64), } @@ -53,6 +54,7 @@ class XPUTestCastOp(XPUOpTestWrapper): 'float32', 'int32', 'int64', + 'int8', 'uint8', 'bool', 'float64',