Unverified · Commit ab00c96c authored by lijin23, committed by GitHub

[XPU][PHI Kernels] bind reduce_max_int64 set_value_bool sin_grad_fp32 cos_grad_fp32 for XPU (#55375)

* bind kernels for xpu

* format code

* format code

* 0d support for set value

* refine set_value
Parent: bc61c796
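For context, a minimal usage sketch of what these new bindings enable from the Python API. This is not part of the diff; it assumes a PaddlePaddle build with XPU support and an available XPU device:

```python
import paddle

paddle.set_device("xpu")  # assumption: an XPU build and device are available

# reduce_max: int64 inputs now dispatch to the XPU kernel
x = paddle.to_tensor([[1, 5], [3, 2]], dtype="int64")
print(x.max())  # 5

# set_value: bool tensors can now be assigned into by index
m = paddle.zeros([3], dtype="bool")
m[0] = True

# sin_grad / cos_grad: fp32 backward through sin/cos now runs on XPU
t = paddle.to_tensor([0.0, 1.0], dtype="float32", stop_gradient=False)
paddle.sin(t).sum().backward()
print(t.grad)  # equals cos(t)
```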
@@ -591,7 +591,9 @@ XPUOpMap& get_kl2_ops() {
       {"reduce_any", XPUKernelSet({phi::DataType::BOOL})},
       {"reduce_max_grad", XPUKernelSet({phi::DataType::FLOAT32})},
       {"reduce_max",
-       XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})},
+       XPUKernelSet({phi::DataType::FLOAT32,
+                     phi::DataType::INT32,
+                     phi::DataType::INT64})},
       {"reduce_mean_grad", XPUKernelSet({phi::DataType::FLOAT32})},
       {"reduce_mean",
        XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
@@ -663,7 +665,8 @@ XPUOpMap& get_kl2_ops() {
        XPUKernelSet({phi::DataType::FLOAT32,
                      phi::DataType::INT32,
                      phi::DataType::INT64,
-                     phi::DataType::FLOAT16})},
+                     phi::DataType::FLOAT16,
+                     phi::DataType::BOOL})},
       {"set_value_grad",
        XPUKernelSet({phi::DataType::FLOAT32,
                      phi::DataType::INT32,
@@ -935,7 +938,9 @@ XPUOpMap& get_kl2_ops() {
                      phi::DataType::FLOAT32,
                      phi::DataType::FLOAT16})},
       {"sin", XPUKernelSet({phi::DataType::FLOAT32})},
+      {"sin_grad", XPUKernelSet({phi::DataType::FLOAT32})},
       {"cos", XPUKernelSet({phi::DataType::FLOAT32})},
+      {"cos_grad", XPUKernelSet({phi::DataType::FLOAT32})},
       {"linspace",
        XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})},
       {"randint", XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})},
......
@@ -565,6 +565,52 @@ struct XPUSoftPlusGradFunctor : public funcs::BaseActivationFunctor<T> {
   }
 };
+
+template <typename T>
+struct XPUSinGradFunctor : public funcs::BaseActivationFunctor<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  template <typename Context>
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor* x,
+                  const DenseTensor* out,
+                  const DenseTensor* dout,
+                  DenseTensor* dx) const {
+    int64_t len = dx->numel();
+    auto dx_data = dev_ctx.template Alloc<T>(dx);
+    auto dout_data = dout->data<T>();
+    auto x_data = x->data<T>();
+    int r = xpu::sin_grad<T>(dev_ctx.x_context(),
+                             reinterpret_cast<const XPUType*>(x_data),
+                             reinterpret_cast<const XPUType*>(dout_data),
+                             reinterpret_cast<XPUType*>(dx_data),
+                             len);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "sin_grad");
+  }
+};
+
+template <typename T>
+struct XPUCosGradFunctor : public funcs::BaseActivationFunctor<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  template <typename Context>
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor* x,
+                  const DenseTensor* out,
+                  const DenseTensor* dout,
+                  DenseTensor* dx) const {
+    int64_t len = dx->numel();
+    auto dx_data = dev_ctx.template Alloc<T>(dx);
+    auto dout_data = dout->data<T>();
+    auto x_data = x->data<T>();
+    int r = xpu::cos_grad<T>(dev_ctx.x_context(),
+                             reinterpret_cast<const XPUType*>(x_data),
+                             reinterpret_cast<const XPUType*>(dout_data),
+                             reinterpret_cast<XPUType*>(dx_data),
+                             len);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cos_grad");
+  }
+};
+
 DEFINE_XPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Exp, XPUExpGradFunctor);
 DEFINE_XPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Reciprocal, XPUReciprocalGradFunctor);
 DEFINE_XPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Sigmoid, XPUSigmoidGradFunctor);
@@ -576,6 +622,8 @@ DEFINE_XPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu6, XPURelu6GradFunctor);
 DEFINE_XPU_ACTIVATION_GRAD_KERNEL_DEPX(Log, XPULogGradFunctor);
 DEFINE_XPU_ACTIVATION_GRAD_KERNEL_DEPX(Square, XPUSquareGradFunctor);
 DEFINE_XPU_ACTIVATION_GRAD_KERNEL_DEPX(Swish, XPUSwishGradFunctor);
+DEFINE_XPU_ACTIVATION_GRAD_KERNEL_DEPX(Sin, XPUSinGradFunctor);
+DEFINE_XPU_ACTIVATION_GRAD_KERNEL_DEPX(Cos, XPUCosGradFunctor);
 DEFINE_XPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish,
                                                XPUMishGradFunctor,
@@ -664,4 +712,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(sin_grad, SinGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(cos_grad, CosGradKernel)
 PD_REGISTER_KERNEL(pow_grad, XPU, ALL_LAYOUT, phi::PowGradKernel, float) {}
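The two functors above delegate to XDNN's sin_grad/cos_grad, which compute dx = cos(x) * dout and dx = -sin(x) * dout respectively. As a host-side sanity check of those identities (plain NumPy, independent of the XPU runtime; the reference helpers are ours, not Paddle API):

```python
import numpy as np

def sin_grad_ref(x, dout):
    return np.cos(x) * dout   # d/dx sin(x) = cos(x)

def cos_grad_ref(x, dout):
    return -np.sin(x) * dout  # d/dx cos(x) = -sin(x)

x = np.linspace(-2.0, 2.0, 9)
dout = np.ones_like(x)
eps = 1e-3
# central finite differences of the forward functions
fd_sin = (np.sin(x + eps) - np.sin(x - eps)) / (2 * eps)
fd_cos = (np.cos(x + eps) - np.cos(x - eps)) / (2 * eps)
assert np.allclose(sin_grad_ref(x, dout), fd_sin, atol=1e-6)
assert np.allclose(cos_grad_ref(x, dout), fd_cos, atol=1e-6)
```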
@@ -48,4 +48,4 @@ void MaxKernel(const Context& dev_ctx,
 }  // namespace phi
-PD_REGISTER_KERNEL(max, XPU, ALL_LAYOUT, phi::MaxKernel, float, int) {}
+PD_REGISTER_KERNEL(max, XPU, ALL_LAYOUT, phi::MaxKernel, float, int, int64_t) {}
@@ -14,16 +14,12 @@
#include "paddle/phi/kernels/set_value_kernel.h"
#include <algorithm>
#include <vector>
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"
#include "paddle/phi/kernels/xpu/elementwise.h"
@@ -84,6 +80,14 @@ void SetValueImpl(const Context& dev_ctx,
                   DenseTensor* out) {
   using XPUType = typename XPUTypeTrait<T>::Type;
   auto in_dims = in.dims();
+  auto new_value_dims = value_dims;
+  // support for 0-d tensor
+  if (value_dims.size() == 0) {
+    new_value_dims = {1};
+  }
   std::vector<int64_t> starts_local = starts.GetData();
   std::vector<int64_t> ends_local = ends.GetData();
   std::vector<int64_t> steps_local = steps.GetData();
@@ -165,41 +169,7 @@ void SetValueImpl(const Context& dev_ctx,
       return;
     }
   }
-
-  // Because strided_slice does not support the case of stride < 0
-  // temporarily, the coordinates of starts_indices, ends_indices
-  // and strides_indices need to be converted.
-  // This logic may be deleted in the future.
-  bool need_flip = false;
-  for (size_t i = 0; i < RANK; ++i) {
-    if (strides_indices[i] < 0) {
-      if (!need_flip) {
-        need_flip = true;
-      }
-      flip_axis.push_back(i);
-      strides_indices[i] = strides_indices[i] * (-1);
-      ends_indices[i] = starts_indices[i] + 1;
-      starts_indices[i] =
-          starts_indices[i] - (slice_dims[i] - 1) * strides_indices[i];
-    }
-  }
-
-  auto out_shape = phi::vectorize<int>(out->dims());
-  auto slice_shape = phi::vectorize<int>(slice_dims);
-  r = xpu::strided_slice(dev_ctx.x_context(),
-                         reinterpret_cast<const XPUType*>(out->data<T>()),
-                         slice_data,
-                         out_shape,
-                         starts_indices,
-                         ends_indices,
-                         strides_indices);
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "strided_slice");
-
-  r = xpu::constant(dev_ctx.x_context(), slice_data, slice_numels, XPUType(0));
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
-
-  // Step 2: Set a tensor with the same shape as out tensor. And its data at
-  // '_index' is the same as value, and data out of '_index' to zero
+  // Step 2: Set slice tensor
+  // - Step 2.1 Set slice tensor with value
@@ -215,26 +185,50 @@ void SetValueImpl(const Context& dev_ctx,
   // If do broadcasting on Tensor with shape [3] and [3], the result's shape
   // is [3], which is right.
-  CheckIsDimsMatch(slice_dims_for_assign, value_dims);
-  // XPUElementwise can do broadcasting
+  CheckIsDimsMatch(slice_dims_for_assign, new_value_dims);
+  // do broadcasting
   auto f = [](xpu::Context* ctx,
               const XPUType* x,
-              const XPUType* y,
+              const XPUType* y, /*unused*/
               XPUType* z,
               const std::vector<int>& xshape,
-              const std::vector<int>& yshape) {
-    return xpu::broadcast_add<XPUType>(ctx, x, y, z, xshape, yshape);
+              const std::vector<int>& zshape) {
+    return xpu::broadcast<XPUType>(ctx, x, z, xshape, zshape);
   };
   XPUElementwise<T, XPUType>(dev_ctx,
                              reinterpret_cast<const T*>(slice_data),
                              slice_dims_for_assign,
                              value_data,
-                             value_dims,
+                             new_value_dims,
                              nullptr,
                              slice_dims_for_assign,
                              -1,
                              reinterpret_cast<T*>(slice_data),
                              f);
+
+  // - Step 2.2 If stride < 0, flip the slice_tensor.
+  // Because strided_slice_view_update does not support the case of stride < 0
+  // temporarily, the coordinates of starts_indices, ends_indices
+  // and strides_indices need to be converted.
+  // This logic may be deleted in the future.
+  bool need_flip = false;
+  for (size_t i = 0; i < RANK; ++i) {
+    if (strides_indices[i] < 0) {
+      if (!need_flip) {
+        need_flip = true;
+      }
+      flip_axis.push_back(i);
+      strides_indices[i] = strides_indices[i] * (-1);
+      ends_indices[i] = starts_indices[i] + 1;
+      starts_indices[i] =
+          starts_indices[i] - (slice_dims[i] - 1) * strides_indices[i];
+    }
+  }
+  auto out_shape = phi::vectorize<int>(out->dims());
+  auto slice_shape = phi::vectorize<int>(slice_dims);
   if (need_flip) {
     r = xpu::flip(dev_ctx.x_context(),
                   reinterpret_cast<const XPUType*>(slice_data),
@@ -392,29 +386,20 @@ void SetValueKernel(const Context& dev_ctx,
                     const std::vector<int64_t>& shape,
                     const std::vector<Scalar>& values,
                     DenseTensor* out) {
-  using XPUType = typename XPUTypeTrait<T>::Type;
   std::vector<T> assign_values;
   assign_values.reserve(values.size());
   for (const auto& val : values) {
     assign_values.push_back(val.to<T>());
   }
-  xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
   auto value_dims = phi::make_ddim(shape);
-  XPUType* value_data =
-      RAII_GUARD.alloc_l3_or_gm<XPUType>(phi::product(value_dims));
-  phi::CPUPlace src_place;
-  auto dst_place = dev_ctx.GetPlace();
-  memory_utils::Copy(dst_place,
-                     value_data,
-                     src_place,
-                     assign_values.data(),
-                     assign_values.size() * sizeof(T));
+  DenseTensor value_tensor;
+  TensorFromVector<T>(assign_values, dev_ctx, &value_tensor);

   SetValueKernelImpl<T, Context>(dev_ctx,
                                  x,
-                                 reinterpret_cast<const T*>(value_data),
+                                 value_tensor.data<T>(),
                                  value_dims,
                                  starts,
                                  ends,
@@ -434,7 +419,8 @@ PD_REGISTER_KERNEL(set_value,
                    float,
                    phi::dtype::float16,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   bool) {}

 PD_REGISTER_KERNEL(set_value_with_tensor,
                    XPU,
......
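Two user-visible effects of the set_value changes above, sketched with standard paddle indexing (an assumption-laden example, not part of the diff; it presumes an XPU build and device): a 0-d value tensor is now promoted to shape {1} before broadcasting, and bool joins the registered dtypes.

```python
import paddle

paddle.set_device("xpu")  # assumption: an XPU build and device

# 0-d value tensor: promoted to shape [1] internally, then broadcast
x = paddle.ones([2, 3], dtype="float32")
x[0, :] = paddle.to_tensor(7.0)  # right-hand side is a 0-d tensor

# bool is now a registered set_value dtype on XPU
flags = paddle.zeros([4], dtype="bool")
flags[1:3] = True
```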
@@ -57,6 +57,8 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
         def set_dtype(self):
             self.dtype = self.in_type
+            if self.in_type == np.bool_:
+                self.dtype = "bool"

         def _call_setitem(self, x):
             x[0, 0] = self.value
@@ -215,6 +217,8 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
         def set_dtype(self):
             if self.in_type == np.float16:
                 self.dtype = "float32"
+            elif self.in_type == np.bool_:
+                self.dtype = "bool"
             else:
                 self.dtype = self.in_type
@@ -308,6 +312,8 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
         def set_dtype(self):
             if self.in_type == np.float16:
                 self.dtype = "float32"
+            elif self.in_type == np.bool_:
+                self.dtype = "bool"
             else:
                 self.dtype = self.in_type
@@ -327,13 +333,9 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
         def _get_answer(self):
             self.data[5:2:-1] = self.value

-    class XPUTestSetValueItemSliceNegetiveStep2(XPUTestSetValueApi):
-        def set_dtype(self):
-            if self.in_type == np.float16:
-                self.dtype = "float32"
-            else:
-                self.dtype = self.in_type
-
+    class XPUTestSetValueItemSliceNegetiveStep2(
+        XPUTestSetValueItemSliceNegetiveStep
+    ):
         def set_shape(self):
             self.shape = [5]
@@ -351,13 +353,9 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
         def _get_answer(self):
             self.data[1::-1] = self.value

-    class XPUTestSetValueItemSliceNegetiveStep3(XPUTestSetValueApi):
-        def set_dtype(self):
-            if self.in_type == np.float16:
-                self.dtype = "float32"
-            else:
-                self.dtype = self.in_type
-
+    class XPUTestSetValueItemSliceNegetiveStep3(
+        XPUTestSetValueItemSliceNegetiveStep
+    ):
         def set_shape(self):
             self.shape = [3]
@@ -375,6 +373,14 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
             self.data[::-1] = self.value

     class XPUTestSetValueItemSliceNegetiveStep4(XPUTestSetValueApi):
+        def set_dtype(self):
+            if self.in_type == np.float16:
+                self.dtype = "float32"
+            elif self.in_type == np.bool_:
+                self.dtype = "bool"
+            else:
+                self.dtype = self.in_type
+
         def set_shape(self):
             self.shape = [3, 4, 5]
@@ -395,6 +401,14 @@ class XPUTestSetValueOp(XPUOpTestWrapper):

     # 1.2.3 step < 0 and stride < -1
     class XPUTestSetValueItemSliceNegetiveStep5(XPUTestSetValueApi):
+        def set_dtype(self):
+            if self.in_type == np.float16:
+                self.dtype = "float32"
+            elif self.in_type == np.bool_:
+                self.dtype = "bool"
+            else:
+                self.dtype = self.in_type
+
         def set_shape(self):
             self.shape = [5, 5, 5]
@@ -484,6 +498,8 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
         def set_dtype(self):
             if self.in_type == np.float16:
                 self.dtype = "float32"
+            elif self.in_type == np.bool_:
+                self.dtype = "bool"
             else:
                 self.dtype = self.in_type
@@ -499,13 +515,7 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
         def _get_answer(self):
             self.data[0] = self.value

-    class XPUTestSetValueItemTensor2(XPUTestSetValueApi):
-        def set_dtype(self):
-            if self.in_type == np.float16:
-                self.dtype = "float32"
-            else:
-                self.dtype = self.in_type
-
+    class XPUTestSetValueItemTensor2(XPUTestSetValueItemTensor):
         def _call_setitem(self, x):
             zero = paddle.full([1], 0, dtype="int32")
             two = paddle.full([1], 2, dtype="int64")
@@ -520,13 +530,7 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
         def _get_answer(self):
             self.data[0:2] = self.value

-    class XPUTestSetValueItemTensor3(XPUTestSetValueApi):
-        def set_dtype(self):
-            if self.in_type == np.float16:
-                self.dtype = "float32"
-            else:
-                self.dtype = self.in_type
-
+    class XPUTestSetValueItemTensor3(XPUTestSetValueItemTensor):
         def _call_setitem(self, x):
             zero = paddle.full([1], 0, dtype="int32")
             two = paddle.full([1], 2, dtype="int64")
@@ -543,13 +547,7 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
         def _get_answer(self):
             self.data[0:-1, 0:2] = self.value

-    class XPUTestSetValueItemTensor4(XPUTestSetValueApi):
-        def set_dtype(self):
-            if self.in_type == np.float16:
-                self.dtype = "float32"
-            else:
-                self.dtype = self.in_type
-
+    class XPUTestSetValueItemTensor4(XPUTestSetValueItemTensor):
         def _call_setitem(self, x):
             zero = paddle.full([1], 0, dtype="int32")
             two = paddle.full([1], 2, dtype="int64")
@@ -566,13 +564,7 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
         def _get_answer(self):
             self.data[0:-1, 0:2, ::2] = self.value

-    class XPUTestSetValueItemTensor5(XPUTestSetValueApi):
-        def set_dtype(self):
-            if self.in_type == np.float16:
-                self.dtype = "float32"
-            else:
-                self.dtype = self.in_type
-
+    class XPUTestSetValueItemTensor5(XPUTestSetValueItemTensor):
         def _call_setitem(self, x):
             zero = paddle.full([1], 0, dtype="int32")
             two = paddle.full([1], 2, dtype="int64")
@@ -591,13 +583,7 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
         def _get_answer(self):
             self.data[0:, 1:2:2, :] = self.value

-    class XPUTestSetValueItemTensor6(XPUTestSetValueApi):
-        def set_dtype(self):
-            if self.in_type == np.float16:
-                self.dtype = "float32"
-            else:
-                self.dtype = self.in_type
-
+    class XPUTestSetValueItemTensor6(XPUTestSetValueItemTensor):
         def set_shape(self):
             self.shape = [3, 4, 5]
@@ -624,6 +610,8 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
         def set_dtype(self):
             if self.in_type == np.float16:
                 self.dtype = "float32"
+            elif self.in_type == np.bool_:
+                self.dtype = "bool"
             else:
                 self.dtype = self.in_type
@@ -637,13 +625,7 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
         def _get_answer(self):
             self.data[None] = self.value

-    class XPUTestSetValueItemNone2(XPUTestSetValueApi):
-        def set_dtype(self):
-            if self.in_type == np.float16:
-                self.dtype = "float32"
-            else:
-                self.dtype = self.in_type
-
+    class XPUTestSetValueItemNone2(XPUTestSetValueItemNone1):
         def _call_setitem(self, x):
             x[0, None, 1] = self.value
@@ -654,13 +636,7 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
         def _get_answer(self):
             self.data[0, None, 1] = self.value

-    class XPUTestSetValueItemNone3(XPUTestSetValueApi):
-        def set_dtype(self):
-            if self.in_type == np.float16:
-                self.dtype = "float32"
-            else:
-                self.dtype = self.in_type
-
+    class XPUTestSetValueItemNone3(XPUTestSetValueItemNone1):
         def _call_setitem(self, x):
             x[:, None, None, 1] = self.value
@@ -673,13 +649,7 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
         def _get_answer(self):
             self.data[:, None, None, 1] = self.value

-    class XPUTestSetValueItemNone4(XPUTestSetValueApi):
-        def set_dtype(self):
-            if self.in_type == np.float16:
-                self.dtype = "float32"
-            else:
-                self.dtype = self.in_type
-
+    class XPUTestSetValueItemNone4(XPUTestSetValueItemNone1):
         def _call_setitem(self, x):
             x[0, 0, None, 1] = self.value
@@ -690,13 +660,7 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
         def _get_answer(self):
             self.data[0, 0, None, 1] = self.value

-    class XPUTestSetValueItemNone5(XPUTestSetValueApi):
-        def set_dtype(self):
-            if self.in_type == np.float16:
-                self.dtype = "float32"
-            else:
-                self.dtype = self.in_type
-
+    class XPUTestSetValueItemNone5(XPUTestSetValueItemNone1):
         def _call_setitem(self, x):
             x[0, None, 0, None, 1] = self.value
@@ -707,13 +671,7 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
         def _get_answer(self):
             self.data[0, None, 0, None, 1] = self.value

-    class XPUTestSetValueItemNone6(XPUTestSetValueApi):
-        def set_dtype(self):
-            if self.in_type == np.float16:
-                self.dtype = "float32"
-            else:
-                self.dtype = self.in_type
-
+    class XPUTestSetValueItemNone6(XPUTestSetValueItemNone1):
         def _call_setitem(self, x):
             x[None, 0, 0, None, 0] = self.value
@@ -724,13 +682,7 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
         def _get_answer(self):
             self.data[None, 0, 0, None, 0] = self.value

-    class XPUTestSetValueItemNone7(XPUTestSetValueApi):
-        def set_dtype(self):
-            if self.in_type == np.float16:
-                self.dtype = "float32"
-            else:
-                self.dtype = self.in_type
-
+    class XPUTestSetValueItemNone7(XPUTestSetValueItemNone1):
         def _call_setitem(self, x):
             x[:, None, 1] = np.zeros(self.shape)[:, None, 0]
@@ -745,13 +697,7 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
         def _get_answer(self):
             self.data[:, None, 1] = np.zeros(self.shape)[:, None, 0]

-    class XPUTestSetValueItemNone8(XPUTestSetValueApi):
-        def set_dtype(self):
-            if self.in_type == np.float16:
-                self.dtype = "float32"
-            else:
-                self.dtype = self.in_type
-
+    class XPUTestSetValueItemNone8(XPUTestSetValueItemNone1):
         def _call_setitem(self, x):
             x[:, 1, None] = np.zeros(self.shape)[:, 0, None]
@@ -766,13 +712,7 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
         def _get_answer(self):
             self.data[:, 1, None] = np.zeros(self.shape)[:, 0, None]

-    class XPUTestSetValueItemNone9(XPUTestSetValueApi):
-        def set_dtype(self):
-            if self.in_type == np.float16:
-                self.dtype = "float32"
-            else:
-                self.dtype = self.in_type
-
+    class XPUTestSetValueItemNone9(XPUTestSetValueItemNone1):
         def _call_setitem(self, x):
             x[None, :, 1, ..., None] = np.zeros(self.shape)[0, 0, :, None]
@@ -789,13 +729,7 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
                 0, 0, :, None
             ]

-    class XPUTestSetValueItemNone10(XPUTestSetValueApi):
-        def set_dtype(self):
-            if self.in_type == np.float16:
-                self.dtype = "float32"
-            else:
-                self.dtype = self.in_type
-
+    class XPUTestSetValueItemNone10(XPUTestSetValueItemNone1):
         def _call_setitem(self, x):
             x[..., None, :, None] = np.zeros(self.shape)[..., None, :, None]
@@ -815,10 +749,7 @@ class XPUTestSetValueOp(XPUOpTestWrapper):

     # 1.6 item is list or Tensor of bol
     class XPUTestSetValueItemBool1(XPUTestSetValueApi):
         def set_dtype(self):
-            if self.in_type == np.float16:
-                self.dtype = "float32"
-            else:
-                self.dtype = self.in_type
+            self.dtype = "float32"

         def _call_setitem(self, x):
             x[[True, False]] = self.value
@@ -832,10 +763,7 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
     class XPUTestSetValueItemBool2(XPUTestSetValueApi):
         def set_dtype(self):
-            if self.in_type == np.float16:
-                self.dtype = "float32"
-            else:
-                self.dtype = self.in_type
+            self.dtype = "float32"

         def _call_setitem(self, x):
             x[[False, False]] = self.value
@@ -849,10 +777,7 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
     class XPUTestSetValueItemBool3(XPUTestSetValueApi):
         def set_dtype(self):
-            if self.in_type == np.float16:
-                self.dtype = "float32"
-            else:
-                self.dtype = self.in_type
+            self.dtype = "float32"

         def _call_setitem(self, x):
             x[[False, True]] = np.zeros(self.shape[2])
@@ -866,10 +791,7 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
     class XPUTestSetValueItemBool4(XPUTestSetValueApi):
         def set_dtype(self):
-            if self.in_type == np.float16:
-                self.dtype = "float32"
-            else:
-                self.dtype = self.in_type
+            self.dtype = "float32"

         def _call_setitem(self, x):
             idx = paddle.assign(np.array([False, True]))
@@ -885,10 +807,7 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
     class XPUTestSetValueItemBool5(XPUTestSetValueApi):
         def set_dtype(self):
-            if self.in_type == np.float16:
-                self.dtype = "float32"
-            else:
-                self.dtype = self.in_type
+            self.dtype = "float32"

         def _call_setitem(self, x):
             idx = paddle.assign(
@@ -910,10 +829,7 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
     class XPUTestSetValueItemBool6(XPUTestSetValueApi):
         def set_dtype(self):
-            if self.in_type == np.float16:
-                self.dtype = "float32"
-            else:
-                self.dtype = self.in_type
+            self.dtype = "float32"

         def _call_setitem(self, x):
             x[0, ...] = 0
@@ -1047,6 +963,8 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
         def set_dtype(self):
             if self.in_type == np.float16:
                 self.dtype = "float32"
+            elif self.in_type == np.bool_:
+                self.dtype = "bool"
             else:
                 self.dtype = self.in_type
@@ -1063,13 +981,7 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
         def _get_answer(self):
             self.data[0] = self.value

-    class XPUTestSetValueValueShape2(XPUTestSetValueApi):
-        def set_dtype(self):
-            if self.in_type == np.float16:
-                self.dtype = "float32"
-            else:
-                self.dtype = self.in_type
-
+    class XPUTestSetValueValueShape2(XPUTestSetValueValueShape1):
         def set_value(self):
             self.value = np.array([[3, 4, 5, 6]])  # shape is (1,4)
@@ -1083,13 +995,7 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
         def _get_answer(self):
             self.data[0:1] = self.value

-    class XPUTestSetValueValueShape3(XPUTestSetValueApi):
-        def set_dtype(self):
-            if self.in_type == np.float16:
-                self.dtype = "float32"
-            else:
-                self.dtype = self.in_type
-
+    class XPUTestSetValueValueShape3(XPUTestSetValueValueShape1):
         def set_value(self):
             self.value = np.array(
                 [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]
@@ -1105,13 +1011,7 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
         def _get_answer(self):
             self.data[0] = self.value

-    class XPUTestSetValueValueShape4(XPUTestSetValueApi):
-        def set_dtype(self):
-            if self.in_type == np.float16:
-                self.dtype = "float32"
-            else:
-                self.dtype = self.in_type
-
+    class XPUTestSetValueValueShape4(XPUTestSetValueValueShape1):
        def set_value(self):
            self.value = np.array(
                [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]
@@ -1129,13 +1029,7 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
         def _get_answer(self):
             self.data[0] = self.value

-    class XPUTestSetValueValueShape5(XPUTestSetValueApi):
-        def set_dtype(self):
-            if self.in_type == np.float16:
-                self.dtype = "float32"
-            else:
-                self.dtype = self.in_type
-
+    class XPUTestSetValueValueShape5(XPUTestSetValueValueShape1):
         def set_value(self):
             self.value = np.array([3, 3, 3]).astype(self.dtype)
......