未验证 提交 95c3d613 编写于 作者: Y Yiqun Liu 提交者: GitHub

Cherry pick the support of bfloat16 for several operators. (#52608)

* Register exp/expm1/logit bf16 activation op kernels (#48702)

* register more bf16 ops

* update to register coresponding backward ops

* Addition of bf16 type support for Compare OP  (#46413)

* first commit

* clarify the quotes

* change code style format

* support bfloat16

* add bfloat16 support for more ops (#48272)

* [Bfloat16]register bfloat16 datatype for squared l2 norm (#50908)

* Sync the pull request #51903.

* Add some header files back.

* modify cmake file for cuda11.8 compile (#49020)

* modify cmake file for cuda11.8 compile

* add op_library(fused_embedding_eltwise_layernorm_op DEPS bert_encoder_functor)

* Fix compling error.

* Cherry-pick pull request #51396.

---------
Co-authored-by: Nsneaxiy <32832641+sneaxiy@users.noreply.github.com>
Co-authored-by: Nlimingshu <61349199+JamesLim-sy@users.noreply.github.com>
Co-authored-by: shaojie_wang's avatarShaojie WANG <wsjmessi@163.com>
Co-authored-by: Nzqw_1997 <118182234+zhengqiwen1997@users.noreply.github.com>
上级 73473ac2
...@@ -68,6 +68,11 @@ __device__ __inline__ void load_data_upper_tri(plat::float16* dst, ...@@ -68,6 +68,11 @@ __device__ __inline__ void load_data_upper_tri(plat::float16* dst,
*(reinterpret_cast<float2*>(dst)) = *(reinterpret_cast<const float2*>(src)); *(reinterpret_cast<float2*>(dst)) = *(reinterpret_cast<const float2*>(src));
} }
__device__ __inline__ void load_data_upper_tri(plat::bfloat16* dst,
const plat::bfloat16* src) {
*(reinterpret_cast<float2*>(dst)) = *(reinterpret_cast<const float2*>(src));
}
__device__ __inline__ void load_data_upper_tri(float* dst, const float* src) { __device__ __inline__ void load_data_upper_tri(float* dst, const float* src) {
*(reinterpret_cast<float4*>(dst)) = *(reinterpret_cast<const float4*>(src)); *(reinterpret_cast<float4*>(dst)) = *(reinterpret_cast<const float4*>(src));
} }
...@@ -76,6 +81,10 @@ __device__ __inline__ void load_zero_vector_upper_tri(plat::float16* dst) { ...@@ -76,6 +81,10 @@ __device__ __inline__ void load_zero_vector_upper_tri(plat::float16* dst) {
*(reinterpret_cast<float2*>(dst)) = make_float2(0.0f, 0.0f); *(reinterpret_cast<float2*>(dst)) = make_float2(0.0f, 0.0f);
} }
__device__ __inline__ void load_zero_vector_upper_tri(plat::bfloat16* dst) {
*(reinterpret_cast<float2*>(dst)) = make_float2(0.0f, 0.0f);
}
__device__ __inline__ void load_zero_vector_upper_tri(float* dst) { __device__ __inline__ void load_zero_vector_upper_tri(float* dst) {
*(reinterpret_cast<float4*>(dst)) = make_float4(0.0f, 0.0f, 0.0f, 0.0f); *(reinterpret_cast<float4*>(dst)) = make_float4(0.0f, 0.0f, 0.0f, 0.0f);
} }
...@@ -595,8 +604,11 @@ namespace plat = paddle::platform; ...@@ -595,8 +604,11 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
fused_softmax_mask_upper_triangle, fused_softmax_mask_upper_triangle,
ops::SoftmaxMaskFuseUpperTriangleKernel<phi::GPUContext, plat::float16>, ops::SoftmaxMaskFuseUpperTriangleKernel<phi::GPUContext, plat::float16>,
ops::SoftmaxMaskFuseUpperTriangleKernel<phi::GPUContext, plat::bfloat16>,
ops::SoftmaxMaskFuseUpperTriangleKernel<phi::GPUContext, float>); ops::SoftmaxMaskFuseUpperTriangleKernel<phi::GPUContext, float>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
fused_softmax_mask_upper_triangle_grad, fused_softmax_mask_upper_triangle_grad,
ops::SoftmaxMaskFuseUpperTriangleGradKernel<phi::GPUContext, plat::float16>, ops::SoftmaxMaskFuseUpperTriangleGradKernel<phi::GPUContext, plat::float16>,
ops::SoftmaxMaskFuseUpperTriangleGradKernel<phi::GPUContext,
plat::bfloat16>,
ops::SoftmaxMaskFuseUpperTriangleGradKernel<phi::GPUContext, float>); ops::SoftmaxMaskFuseUpperTriangleGradKernel<phi::GPUContext, float>);
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include "math.h" // NOLINT #include "math.h" // NOLINT
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/hostdevice.h"
...@@ -33,6 +34,10 @@ inline HOSTDEVICE platform::float16 real_log(platform::float16 x) { ...@@ -33,6 +34,10 @@ inline HOSTDEVICE platform::float16 real_log(platform::float16 x) {
return static_cast<platform::float16>(::logf(static_cast<float>(x))); return static_cast<platform::float16>(::logf(static_cast<float>(x)));
} }
inline HOSTDEVICE phi::dtype::bfloat16 real_log(phi::dtype::bfloat16 x) {
return static_cast<phi::dtype::bfloat16>(::logf(static_cast<float>(x)));
}
inline HOSTDEVICE float real_log(float x) { return ::logf(x); } inline HOSTDEVICE float real_log(float x) { return ::logf(x); }
inline HOSTDEVICE double real_log(double x) { return ::log(x); } inline HOSTDEVICE double real_log(double x) { return ::log(x); }
......
...@@ -13,9 +13,11 @@ See the License for the specific language governing permissions and ...@@ -13,9 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/math.h" #include "paddle/fluid/operators/math.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
...@@ -152,7 +154,10 @@ void CrossEntropyFunctor<DeviceContext, T>::operator()( ...@@ -152,7 +154,10 @@ void CrossEntropyFunctor<DeviceContext, T>::operator()(
template class CrossEntropyFunctor<phi::GPUContext, float>; template class CrossEntropyFunctor<phi::GPUContext, float>;
template class CrossEntropyFunctor<phi::GPUContext, double>; template class CrossEntropyFunctor<phi::GPUContext, double>;
template class CrossEntropyFunctor<phi::GPUContext, platform::float16>; template class CrossEntropyFunctor<phi::GPUContext, phi::dtype::float16>;
#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION_MIN(8, 1, 0)
template class CrossEntropyFunctor<phi::GPUContext, phi::dtype::bfloat16>;
#endif
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -17,7 +17,8 @@ limitations under the License. */ ...@@ -17,7 +17,8 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/hostdevice.h"
namespace paddle { namespace paddle {
...@@ -46,14 +47,30 @@ struct TolerableValue { ...@@ -46,14 +47,30 @@ struct TolerableValue {
// Also. In standard implementation of cross entropy, other // Also. In standard implementation of cross entropy, other
// framework not has the ValueClipping. // framework not has the ValueClipping.
template <> template <>
struct TolerableValue<platform::float16> { struct TolerableValue<phi::dtype::float16> {
HOSTDEVICE platform::float16 operator()(const platform::float16& x) const { HOSTDEVICE phi::dtype::float16 operator()(
if (platform::isfinite(x)) const phi::dtype::float16& x) const {
if (phi::dtype::isfinite(x)) {
return x; return x;
else if (x > static_cast<platform::float16>(0)) } else if (x > static_cast<phi::dtype::float16>(0)) {
return std::numeric_limits<platform::float16>::max(); return std::numeric_limits<phi::dtype::float16>::max();
else } else {
return std::numeric_limits<platform::float16>::min(); return std::numeric_limits<phi::dtype::float16>::min();
}
}
};
template <>
struct TolerableValue<phi::dtype::bfloat16> {
HOSTDEVICE phi::dtype::bfloat16 operator()(
const phi::dtype::bfloat16& x) const {
if (phi::dtype::isfinite(x)) {
return x;
} else if (x > static_cast<phi::dtype::bfloat16>(0)) {
return std::numeric_limits<phi::dtype::bfloat16>::max();
} else {
return std::numeric_limits<phi::dtype::bfloat16>::min();
}
} }
}; };
......
...@@ -59,7 +59,8 @@ PD_REGISTER_KERNEL(full_batch_size_like, ...@@ -59,7 +59,8 @@ PD_REGISTER_KERNEL(full_batch_size_like,
int, int,
int64_t, int64_t,
bool, bool,
phi::dtype::float16) { phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
} }
#endif #endif
...@@ -370,7 +370,8 @@ PD_REGISTER_KERNEL(exp_grad, ...@@ -370,7 +370,8 @@ PD_REGISTER_KERNEL(exp_grad,
double, double,
int, int,
int64_t, int64_t,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_ACTIVATION_GRAD_KERNEL(soft_shrink_grad, SoftShrinkGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(soft_shrink_grad, SoftShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_shrink_grad, HardShrinkGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_shrink_grad, HardShrinkGradKernel)
...@@ -385,7 +386,8 @@ PD_REGISTER_KERNEL(expm1_grad, ...@@ -385,7 +386,8 @@ PD_REGISTER_KERNEL(expm1_grad,
phi::Expm1GradKernel, phi::Expm1GradKernel,
float, float,
double, double,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(logit_grad, PD_REGISTER_KERNEL(logit_grad,
GPU, GPU,
...@@ -393,7 +395,8 @@ PD_REGISTER_KERNEL(logit_grad, ...@@ -393,7 +395,8 @@ PD_REGISTER_KERNEL(logit_grad,
phi::LogitGradKernel, phi::LogitGradKernel,
float, float,
double, double,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(square_grad, PD_REGISTER_KERNEL(square_grad,
GPU, GPU,
......
...@@ -212,21 +212,24 @@ PD_REGISTER_KERNEL(exp, ...@@ -212,21 +212,24 @@ PD_REGISTER_KERNEL(exp,
double, double,
int, int,
int64_t, int64_t,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(expm1, PD_REGISTER_KERNEL(expm1,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::Expm1Kernel, phi::Expm1Kernel,
float, float,
double, double,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(logit, PD_REGISTER_KERNEL(logit,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::LogitKernel, phi::LogitKernel,
float, float,
double, double,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(square, PD_REGISTER_KERNEL(square,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
......
...@@ -255,6 +255,7 @@ PD_REGISTER_KERNEL(arg_min, ...@@ -255,6 +255,7 @@ PD_REGISTER_KERNEL(arg_min,
ALL_LAYOUT, ALL_LAYOUT,
phi::ArgMinKernel, phi::ArgMinKernel,
phi::dtype::float16, phi::dtype::float16,
phi::dtype::bfloat16,
float, float,
double, double,
int32_t, int32_t,
...@@ -267,6 +268,7 @@ PD_REGISTER_KERNEL(arg_max, ...@@ -267,6 +268,7 @@ PD_REGISTER_KERNEL(arg_max,
ALL_LAYOUT, ALL_LAYOUT,
phi::ArgMaxKernel, phi::ArgMaxKernel,
phi::dtype::float16, phi::dtype::float16,
phi::dtype::bfloat16,
float, float,
double, double,
int32_t, int32_t,
......
...@@ -282,6 +282,7 @@ void CrossEntropyWithSoftmaxGradKernel(const Context& dev_ctx, ...@@ -282,6 +282,7 @@ void CrossEntropyWithSoftmaxGradKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(cross_entropy_with_softmax_grad, PD_REGISTER_KERNEL(cross_entropy_with_softmax_grad,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -289,3 +290,23 @@ PD_REGISTER_KERNEL(cross_entropy_with_softmax_grad, ...@@ -289,3 +290,23 @@ PD_REGISTER_KERNEL(cross_entropy_with_softmax_grad,
float, float,
double, double,
phi::dtype::float16) {} phi::dtype::float16) {}
#else
#if CUDNN_VERSION_MIN(8, 1, 0)
PD_REGISTER_KERNEL(cross_entropy_with_softmax_grad,
GPU,
ALL_LAYOUT,
phi::CrossEntropyWithSoftmaxGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#else
PD_REGISTER_KERNEL(cross_entropy_with_softmax_grad,
GPU,
ALL_LAYOUT,
phi::CrossEntropyWithSoftmaxGradKernel,
float,
double,
phi::dtype::float16) {}
#endif
#endif
...@@ -252,7 +252,7 @@ __device__ __forceinline__ AccT ThreadReduce(const T* input, ...@@ -252,7 +252,7 @@ __device__ __forceinline__ AccT ThreadReduce(const T* input,
input -= offset; input -= offset;
size += offset; size += offset;
if (tid >= offset) { if (tid >= offset) {
val = reducer(val, input[tid]); val = reducer(val, static_cast<AccT>(input[tid]));
} }
size -= blockDim.x; size -= blockDim.x;
input += blockDim.x; input += blockDim.x;
...@@ -268,14 +268,14 @@ __device__ __forceinline__ AccT ThreadReduce(const T* input, ...@@ -268,14 +268,14 @@ __device__ __forceinline__ AccT ThreadReduce(const T* input,
#pragma unroll #pragma unroll
for (int i = 0; i < VecSize; ++i) { for (int i = 0; i < VecSize; ++i) {
val = reducer(val, ins[i]); val = reducer(val, static_cast<AccT>(ins[i]));
} }
} }
// scalar part // scalar part
tid = size - remain + threadIdx.x; tid = size - remain + threadIdx.x;
for (; tid < size; tid += blockDim.x) { for (; tid < size; tid += blockDim.x) {
val = reducer(val, input[tid]); val = reducer(val, static_cast<AccT>(input[tid]));
} }
return val; return val;
} }
...@@ -1470,6 +1470,16 @@ PD_REGISTER_KERNEL(cross_entropy_with_softmax, ...@@ -1470,6 +1470,16 @@ PD_REGISTER_KERNEL(cross_entropy_with_softmax,
float, float,
phi::dtype::float16) {} phi::dtype::float16) {}
#else #else
#if CUDNN_VERSION_MIN(8, 1, 0)
PD_REGISTER_KERNEL(cross_entropy_with_softmax,
GPU,
ALL_LAYOUT,
phi::CrossEntropyWithSoftmaxKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#else
PD_REGISTER_KERNEL(cross_entropy_with_softmax, PD_REGISTER_KERNEL(cross_entropy_with_softmax,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -1478,3 +1488,4 @@ PD_REGISTER_KERNEL(cross_entropy_with_softmax, ...@@ -1478,3 +1488,4 @@ PD_REGISTER_KERNEL(cross_entropy_with_softmax,
double, double,
phi::dtype::float16) {} phi::dtype::float16) {}
#endif #endif
#endif
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/phi/kernels/gather_nd_grad_kernel.h" #include "paddle/phi/kernels/gather_nd_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/scatter.cu.h" #include "paddle/phi/kernels/funcs/scatter.cu.h"
...@@ -63,4 +64,5 @@ PD_REGISTER_KERNEL(gather_nd_grad, ...@@ -63,4 +64,5 @@ PD_REGISTER_KERNEL(gather_nd_grad,
double, double,
int64_t, int64_t,
int, int,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::bfloat16) {}
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/phi/kernels/gather_nd_kernel.h" #include "paddle/phi/kernels/gather_nd_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/gather.cu.h" #include "paddle/phi/kernels/funcs/gather.cu.h"
#include "paddle/phi/kernels/funcs/scatter.cu.h" #include "paddle/phi/kernels/funcs/scatter.cu.h"
...@@ -58,4 +59,5 @@ PD_REGISTER_KERNEL(gather_nd, ...@@ -58,4 +59,5 @@ PD_REGISTER_KERNEL(gather_nd,
int, int,
int16_t, int16_t,
bool, bool,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::bfloat16) {}
...@@ -134,6 +134,8 @@ PD_REGISTER_KERNEL(index_sample_grad, ...@@ -134,6 +134,8 @@ PD_REGISTER_KERNEL(index_sample_grad,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::IndexSampleGradKernel, phi::IndexSampleGradKernel,
phi::dtype::float16,
phi::dtype::bfloat16,
float, float,
double, double,
int, int,
......
...@@ -107,6 +107,8 @@ PD_REGISTER_KERNEL(index_sample, ...@@ -107,6 +107,8 @@ PD_REGISTER_KERNEL(index_sample,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::IndexSampleKernel, phi::IndexSampleKernel,
phi::dtype::float16,
phi::dtype::bfloat16,
float, float,
double, double,
int, int,
......
...@@ -25,4 +25,5 @@ PD_REGISTER_KERNEL(tril_triu_grad, ...@@ -25,4 +25,5 @@ PD_REGISTER_KERNEL(tril_triu_grad,
double, double,
int, int,
int64_t, int64_t,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::bfloat16) {}
...@@ -25,4 +25,5 @@ PD_REGISTER_KERNEL(tril_triu, ...@@ -25,4 +25,5 @@ PD_REGISTER_KERNEL(tril_triu,
double, double,
int, int,
int64_t, int64_t,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::bfloat16) {}
...@@ -114,7 +114,8 @@ PD_REGISTER_KERNEL(less_than, ...@@ -114,7 +114,8 @@ PD_REGISTER_KERNEL(less_than,
int64_t, int64_t,
float, float,
double, double,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(less_equal, PD_REGISTER_KERNEL(less_equal,
KPS, KPS,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -125,7 +126,8 @@ PD_REGISTER_KERNEL(less_equal, ...@@ -125,7 +126,8 @@ PD_REGISTER_KERNEL(less_equal,
int64_t, int64_t,
float, float,
double, double,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(greater_than, PD_REGISTER_KERNEL(greater_than,
KPS, KPS,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -136,7 +138,8 @@ PD_REGISTER_KERNEL(greater_than, ...@@ -136,7 +138,8 @@ PD_REGISTER_KERNEL(greater_than,
int64_t, int64_t,
float, float,
double, double,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(greater_equal, PD_REGISTER_KERNEL(greater_equal,
KPS, KPS,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -147,7 +150,8 @@ PD_REGISTER_KERNEL(greater_equal, ...@@ -147,7 +150,8 @@ PD_REGISTER_KERNEL(greater_equal,
int64_t, int64_t,
float, float,
double, double,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(equal, PD_REGISTER_KERNEL(equal,
KPS, KPS,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -158,7 +162,8 @@ PD_REGISTER_KERNEL(equal, ...@@ -158,7 +162,8 @@ PD_REGISTER_KERNEL(equal,
int64_t, int64_t,
float, float,
double, double,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(not_equal, PD_REGISTER_KERNEL(not_equal,
KPS, KPS,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -169,7 +174,8 @@ PD_REGISTER_KERNEL(not_equal, ...@@ -169,7 +174,8 @@ PD_REGISTER_KERNEL(not_equal,
int64_t, int64_t,
float, float,
double, double,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(equal_all, PD_REGISTER_KERNEL(equal_all,
KPS, KPS,
......
...@@ -63,7 +63,8 @@ PD_REGISTER_KERNEL(shape, ...@@ -63,7 +63,8 @@ PD_REGISTER_KERNEL(shape,
double, double,
phi::dtype::complex<float>, phi::dtype::complex<float>,
phi::dtype::complex<double>, phi::dtype::complex<double>,
phi::dtype::float16) { phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
} }
#endif #endif
......
...@@ -3791,8 +3791,17 @@ def gather_nd(x, index, name=None): ...@@ -3791,8 +3791,17 @@ def gather_nd(x, index, name=None):
check_variable_and_dtype( check_variable_and_dtype(
x, x,
'x', 'x',
['bool', 'float32', 'float64', 'int16', 'int32', 'int64'], [
'gather_np', 'bool',
'float16',
'uint16',
'float32',
'float64',
'int16',
'int32',
'int64',
],
'gather_nd',
) )
check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'gather_np') check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'gather_np')
helper = LayerHelper('gather_nd', **locals()) helper = LayerHelper('gather_nd', **locals())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册