未验证 提交 62bff0e0 编写于 作者: L Leo Guo 提交者: GitHub

Add data type of int, int64 for add kernel. Modify the code style of (#50443)

instance_norm_grad kernel. Fix bugs that the data type of input is different from output in reduce_sum kernel. test=kunlun
上级 f7267412
...@@ -120,6 +120,7 @@ XPUOpMap& get_kl2_ops() { ...@@ -120,6 +120,7 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::FLOAT16, phi::DataType::FLOAT16,
phi::DataType::FLOAT64, phi::DataType::FLOAT64,
phi::DataType::BOOL, phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::UINT8, phi::DataType::UINT8,
phi::DataType::INT64, phi::DataType::INT64,
phi::DataType::INT32})}, phi::DataType::INT32})},
...@@ -286,6 +287,7 @@ XPUOpMap& get_kl2_ops() { ...@@ -286,6 +287,7 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::INT64, XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32, phi::DataType::INT32,
phi::DataType::INT16, phi::DataType::INT16,
phi::DataType::INT8,
phi::DataType::UINT8, phi::DataType::UINT8,
phi::DataType::BOOL, phi::DataType::BOOL,
phi::DataType::FLOAT32, phi::DataType::FLOAT32,
......
...@@ -335,4 +335,20 @@ namespace phi { ...@@ -335,4 +335,20 @@ namespace phi {
} \ } \
}() }()
#define PD_VISIT_XPU_REDUCE_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::INT8, int8_t, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::INT32, int32_t, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::INT64, int64_t, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE( \
NAME, ::phi::DataType::FLOAT16, phi::float16, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::FLOAT32, float, __VA_ARGS__) \
default: \
PADDLE_THROW(phi::errors::InvalidArgument( \
"Invalid enum data type `%d`.", static_cast<int>(__dtype__))); \
} \
}()
} // namespace phi } // namespace phi
...@@ -88,6 +88,7 @@ PD_REGISTER_KERNEL(full, ...@@ -88,6 +88,7 @@ PD_REGISTER_KERNEL(full,
phi::FullKernel, phi::FullKernel,
float, float,
double, double,
int8_t,
uint8_t, uint8_t,
int16_t, int16_t,
int, int,
......
...@@ -304,9 +304,14 @@ PD_REGISTER_KERNEL(divide, ...@@ -304,9 +304,14 @@ PD_REGISTER_KERNEL(divide,
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
divide, XPU, ALL_LAYOUT, phi::DivideKernel, phi::dtype::float16, float) {} divide, XPU, ALL_LAYOUT, phi::DivideKernel, phi::dtype::float16, float) {}
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(add,
add, XPU, ALL_LAYOUT, phi::AddKernel, phi::dtype::float16, float, int64_t) { XPU,
} ALL_LAYOUT,
phi::AddKernel,
phi::dtype::float16,
float,
int,
int64_t) {}
PD_REGISTER_KERNEL(multiply, PD_REGISTER_KERNEL(multiply,
XPU, XPU,
......
...@@ -72,6 +72,13 @@ void CastKernel(const Context& dev_ctx, ...@@ -72,6 +72,13 @@ void CastKernel(const Context& dev_ctx,
dev_ctx.template Alloc<bool>(out), dev_ctx.template Alloc<bool>(out),
numel); numel);
break; break;
case phi::DataType::INT8:
r = xpu::cast<XPUInTDType, int8_t>(
dev_ctx.x_context(),
reinterpret_cast<const XPUInTDType*>(in_data),
dev_ctx.template Alloc<int8_t>(out),
numel);
break;
case phi::DataType::UINT8: case phi::DataType::UINT8:
r = xpu::cast<XPUInTDType, uint8_t>( r = xpu::cast<XPUInTDType, uint8_t>(
dev_ctx.x_context(), dev_ctx.x_context(),
...@@ -104,6 +111,7 @@ PD_REGISTER_KERNEL(cast, ...@@ -104,6 +111,7 @@ PD_REGISTER_KERNEL(cast,
phi::dtype::float16, phi::dtype::float16,
int64_t, int64_t,
bool, bool,
int8_t,
uint8_t, uint8_t,
double) { double) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
......
...@@ -119,6 +119,7 @@ PD_REGISTER_KERNEL(full, ...@@ -119,6 +119,7 @@ PD_REGISTER_KERNEL(full,
ALL_LAYOUT, ALL_LAYOUT,
phi::FullKernel, phi::FullKernel,
float, float,
int8_t,
uint8_t, uint8_t,
int16_t, int16_t,
int, int,
......
...@@ -77,27 +77,45 @@ void InstanceNormGradKernel(const Context& dev_ctx, ...@@ -77,27 +77,45 @@ void InstanceNormGradKernel(const Context& dev_ctx,
scale_ptr->dims())); scale_ptr->dims()));
} }
DenseTensor scale_tmp; xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
float* scale_ptr_data_tmp;
int r; int r;
if (!scale_ptr) { if (!scale_ptr) {
scale_tmp.Resize({C}); scale_ptr_data_tmp = RAII_GUARD.alloc_l3_or_gm<float>(C);
dev_ctx.template Alloc<T>(&scale_tmp);
r = xpu::constant(dev_ctx.x_context(), r = xpu::constant(dev_ctx.x_context(),
reinterpret_cast<XPUType*>(scale_tmp.data<T>()), reinterpret_cast<float*>(scale_ptr_data_tmp),
scale_tmp.numel(), C,
static_cast<XPUType>(1)); static_cast<float>(1));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
} }
auto scale_ptr_tmp = scale_ptr ? scale_ptr : &scale_tmp; auto scale_ptr_data =
scale_ptr ? scale_ptr->data<float>() : scale_ptr_data_tmp;
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); if ((H * W * D) == 1) {
r = xpu::copy(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(d_y.data<T>()),
reinterpret_cast<XPUType*>(d_x->data<T>()),
d_y.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
r = xpu::constant(dev_ctx.x_context(),
reinterpret_cast<float*>(d_scale),
C,
static_cast<float>(0));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
r = xpu::constant(dev_ctx.x_context(),
reinterpret_cast<float*>(d_bias),
C,
static_cast<float>(0));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
return;
}
auto d_x_data = auto d_x_data =
d_x ? d_x->data<T>() : RAII_GUARD.alloc_l3_or_gm<T>(x.numel()); d_x ? d_x->data<T>() : RAII_GUARD.alloc_l3_or_gm<T>(x.numel());
r = xpu::instance_norm_grad(dev_ctx.x_context(), r = xpu::instance_norm_grad(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()), reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<const XPUType*>(d_y.data<T>()), reinterpret_cast<const XPUType*>(d_y.data<T>()),
reinterpret_cast<XPUType*>(d_x_data), reinterpret_cast<XPUType*>(d_x_data),
scale_ptr_tmp->data<float>(), scale_ptr_data,
saved_mean.data<float>(), saved_mean.data<float>(),
saved_variance.data<float>(), saved_variance.data<float>(),
d_scale_data, d_scale_data,
......
...@@ -19,6 +19,10 @@ ...@@ -19,6 +19,10 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/xpu/reduce_util.h"
namespace phi { namespace phi {
template <typename Context, typename T> template <typename Context, typename T>
...@@ -82,4 +86,89 @@ int XPUReduce(const Context& dev_ctx, ...@@ -82,4 +86,89 @@ int XPUReduce(const Context& dev_ctx,
return r; return r;
} }
template <typename DeviceContext, typename T, typename OutT, typename Functor>
void ReduceKernelImpl(const DeviceContext& dev_ctx,
const phi::DenseTensor& input,
phi::DenseTensor* output,
const std::vector<int>& xdims,
const std::vector<int>& reduce_dims) {
using XPUType = typename XPUTypeTrait<OutT>::Type;
dev_ctx.template Alloc<OutT>(output);
const auto* x_data = input.data<OutT>();
auto* y_data = output->data<OutT>();
if (reduce_dims.size() == 0) {
int r = xpu::copy<XPUType>(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x_data),
reinterpret_cast<XPUType*>(y_data),
input.numel() * sizeof(T));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
} else {
Functor func;
func(dev_ctx.x_context(), x_data, y_data, xdims, reduce_dims);
}
}
template <typename DeviceContext, typename T, typename Functor>
void XPUReduce(const DeviceContext& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DataType out_dtype,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
const auto& input_dim_size = x.dims().size();
std::vector<int> true_dims;
for (size_t i = 0; i < dims.size(); ++i) {
if (dims[i] < 0) {
true_dims.push_back(dims[i] + input_dim_size);
} else {
true_dims.push_back(dims[i]);
}
}
std::vector<int> reduce_dims;
std::vector<int> xdims((input_dim_size));
for (int i = 0; i < input_dim_size; ++i) {
xdims[i] = x.dims()[i];
}
if (reduce_all) {
for (int i = 0; i < input_dim_size; ++i) {
reduce_dims.push_back(i);
}
} else {
std::set<int> dims_set(true_dims.begin(), true_dims.end());
for (auto i = 0; i < input_dim_size; i++) {
if (dims_set.find(i) != dims_set.end()) {
if (x.dims()[i] != 1) {
reduce_dims.push_back(i);
}
}
}
}
// no need to cast dtype
if (out_dtype == phi::DataType::UNDEFINED || out_dtype == x.dtype()) {
// do reduce sum
PD_VISIT_XPU_REDUCE_TYPES(
x.dtype(), "ReduceKernelImpl", ([&] {
phi::ReduceKernelImpl<DeviceContext, T, data_t, Functor>(
dev_ctx, x, out, xdims, reduce_dims);
}));
} else {
// cast x tensor to out_dtype
auto tmp_tensor = phi::Cast<T, DeviceContext>(dev_ctx, x, out_dtype);
// do reduce sum
PD_VISIT_XPU_REDUCE_TYPES(
out_dtype, "ReduceKernelImpl", ([&] {
phi::ReduceKernelImpl<DeviceContext, T, data_t, Functor>(
dev_ctx, tmp_tensor, out, xdims, reduce_dims);
}));
if (dev_ctx.x_context()->xpu_stream) {
dev_ctx.Wait();
}
}
}
} // namespace phi } // namespace phi
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
#include "paddle/phi/kernels/reduce_sum_kernel.h" #include "paddle/phi/kernels/reduce_sum_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h" #include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/xpu/reduce.h" #include "paddle/phi/kernels/xpu/reduce.h"
...@@ -29,23 +28,11 @@ void SumRawKernel(const Context& dev_ctx, ...@@ -29,23 +28,11 @@ void SumRawKernel(const Context& dev_ctx,
bool reduce_all, bool reduce_all,
DataType out_dtype, DataType out_dtype,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all); if (out_dtype == DataType::UNDEFINED && out->dtype() != x.dtype()) {
using XPUType = typename XPUTypeTrait<T>::Type; out_dtype = out->dtype();
}
auto f = [](xpu::Context* ctx, XPUReduce<Context, T, phi::SumFunctor>(
const T* x, dev_ctx, x, dims.GetData(), keep_dim, reduce_all, out_dtype, out);
T* y,
const std::vector<int>& xdims,
const std::vector<int>& reduce_dims) {
return xpu::reduce_sum<XPUType>(ctx,
reinterpret_cast<const XPUType*>(x),
reinterpret_cast<XPUType*>(y),
xdims,
reduce_dims);
};
int r = XPUReduce<Context, T>(
dev_ctx, x, dims.GetData(), keep_dim, reduce_all, out, f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum");
} }
} // namespace phi } // namespace phi
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
namespace phi {
//////// Sum Functor ///////
struct SumFunctor {
template <typename DeviceContext, typename X, typename Y>
void operator()(const DeviceContext& ctx,
const X* x,
Y* y,
const std::vector<int>& xdims,
const std::vector<int>& reduce_dims) {
using XPUType = typename XPUTypeTrait<X>::Type;
int r = xpu::reduce_sum<XPUType>(ctx,
reinterpret_cast<const XPUType*>(x),
reinterpret_cast<XPUType*>(y),
xdims,
reduce_dims);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum");
}
};
} // namespace phi
...@@ -35,6 +35,7 @@ typeid_dict = { ...@@ -35,6 +35,7 @@ typeid_dict = {
'float32': int(core.VarDesc.VarType.FP32), 'float32': int(core.VarDesc.VarType.FP32),
'float16': int(core.VarDesc.VarType.FP16), 'float16': int(core.VarDesc.VarType.FP16),
'bool': int(core.VarDesc.VarType.BOOL), 'bool': int(core.VarDesc.VarType.BOOL),
'int8': int(core.VarDesc.VarType.INT8),
'uint8': int(core.VarDesc.VarType.UINT8), 'uint8': int(core.VarDesc.VarType.UINT8),
'float64': int(core.VarDesc.VarType.FP64), 'float64': int(core.VarDesc.VarType.FP64),
} }
...@@ -53,6 +54,7 @@ class XPUTestCastOp(XPUOpTestWrapper): ...@@ -53,6 +54,7 @@ class XPUTestCastOp(XPUOpTestWrapper):
'float32', 'float32',
'int32', 'int32',
'int64', 'int64',
'int8',
'uint8', 'uint8',
'bool', 'bool',
'float64', 'float64',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册