未验证 提交 e23dfed9 编写于 作者: Y ykkk2333 提交者: GitHub

Fix paddle rec, kim, dsin models' bugs (#47792)

* add stat tool

* add roll and roll_grad kernels and strided_slice and strided_slice_grad kernels, test=kunlun

* embedding and embedding_grad add int32 input, test=kunlun
上级 a762d68e
......@@ -320,11 +320,75 @@ struct SelectedRowsAddToTensor<phi::CPUContext, T> {
}
};
#ifdef PADDLE_WITH_XPU
template <typename T>
struct SelectedRowsAddToTensor<phi::XPUContext, T> {
void operator()(const phi::XPUContext& context,
const phi::SelectedRows& input1,
phi::DenseTensor* input2) {
if (UNLIKELY(input1.rows().size() == 0)) {
LOG(WARNING) << "input selected rows is empty!";
return;
}
using XPUType = typename XPUTypeTrait<T>::Type;
auto in1_height = input1.height();
const auto& in2_dims = input2->dims();
PADDLE_ENFORCE_EQ(
in1_height,
in2_dims[0],
phi::errors::InvalidArgument("The two inputs height must be equal."
"But received first input height = "
"[%d], second input height = [%d]",
in1_height,
in2_dims[0]));
auto& in1_value = input1.value();
auto& in1_rows = input1.rows();
int64_t* in1_rows_data = nullptr;
xpu::VectorParam<int64_t> in1_rows_vec{
in1_rows.data(), static_cast<int>(in1_rows.size()), in1_rows_data};
int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
PADDLE_ENFORCE_EQ(
in1_row_numel,
input2->numel() / in1_height,
phi::errors::InvalidArgument(
"The two inputs width must be equal."
"But received first input width = [%d], second input width = [%d]",
in1_row_numel,
input2->numel() / in1_height));
auto* in1_data = in1_value.data<T>();
auto* out_data = input2->data<T>();
int h = in1_rows.size();
int w = in1_row_numel;
const std::vector<int> xshape{h, w};
int r = xpu::scatter<XPUType, int64_t>(
context.x_context(),
nullptr,
reinterpret_cast<const XPUType*>(in1_data),
reinterpret_cast<XPUType*>(out_data),
in1_rows_vec,
xshape,
0,
false);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "scatter");
}
};
#endif
template struct SelectedRowsAddToTensor<phi::CPUContext, float>;
template struct SelectedRowsAddToTensor<phi::CPUContext, double>;
template struct SelectedRowsAddToTensor<phi::CPUContext, int>;
template struct SelectedRowsAddToTensor<phi::CPUContext, int64_t>;
template struct SelectedRowsAddToTensor<phi::CPUContext, phi::dtype::bfloat16>;
#ifdef PADDLE_WITH_XPU
template struct SelectedRowsAddToTensor<phi::XPUContext, float>;
#endif
// This is a separated namespace for manipulate SelectedRows typed
// data. Like merge duplicated rows, adding two SelectedRows etc.
//
......
......@@ -17,6 +17,8 @@
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/selected_rows_functor.h"
namespace phi {
template <typename T, typename Context>
......@@ -25,6 +27,8 @@ void AddNKernel(const Context& dev_ctx,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
size_t in_num = x.size();
dev_ctx.template Alloc<T>(out);
bool in_place = false;
if (x.size() > 0 && x[0]->initialized() && DenseTensor::classof(x[0])) {
if ((static_cast<const DenseTensor*>(x[0]))->Holder() == out->Holder()) {
......@@ -33,26 +37,61 @@ void AddNKernel(const Context& dev_ctx,
}
if (!in_place) {
dev_ctx.template Alloc<T>(out);
int r = xpu::constant(dev_ctx.x_context(),
reinterpret_cast<XPUType*>(out->data<T>()),
out->numel(),
XPUType(0));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
}
std::vector<const XPUType*> ptrs;
phi::funcs::SelectedRowsAddToTensor<Context, float> functor;
for (size_t i = 0; i < in_num; ++i) {
PADDLE_ENFORCE_EQ(DenseTensor::classof(x[i]),
true,
errors::InvalidArgument("XPU only support DensorTensor"));
if (DenseTensor::classof(x[i])) {
auto& in_t = *(static_cast<const DenseTensor*>(x[i]));
if (!in_t.initialized() || in_t.numel() == 0) {
continue;
}
ptrs.push_back(reinterpret_cast<const XPUType*>(in_t.data<T>()));
} else if (SelectedRows::classof(x[i])) {
PADDLE_ENFORCE_EQ(x[i]->dtype(),
DataType::FLOAT32,
errors::InvalidArgument("SelectedRowsAdd(scatter) only",
"supports float type"));
auto& in_t = *(static_cast<const DenseTensor*>(x[i]));
if (in_t.numel() == 0) {
continue;
auto& in_t = *(static_cast<const SelectedRows*>(x[i]));
functor(dev_ctx, in_t, out);
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Expected type of Input(X) of %d-th must be Tensor, "
"SelectedRows. But got "
"unsupport type: %s.",
x[i]->type_info().name()));
}
ptrs.push_back(reinterpret_cast<const XPUType*>(in_t.data<T>()));
}
int r = xpu::sum(dev_ctx.x_context(),
ptrs,
reinterpret_cast<XPUType*>(out->data<T>()),
out->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "sum");
if (ptrs.empty()) {
return;
} else if (ptrs.size() < x.size()) {
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
XPUType* out_t = RAII_GUARD.alloc_l3_or_gm<XPUType>(out->numel());
int r = xpu::sum(dev_ctx.x_context(), ptrs, out_t, out->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "sum");
r = xpu::add(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(out->data<T>()),
out_t,
reinterpret_cast<XPUType*>(out->data<T>()),
out->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "add");
} else {
int r = xpu::sum(dev_ctx.x_context(),
ptrs,
reinterpret_cast<XPUType*>(out->data<T>()),
out->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "sum");
}
}
template <typename T, typename Context>
......
......@@ -43,7 +43,18 @@ void EmbeddingGradKernel(const Context& ctx,
"number of ids in LookupTableV2GradXPUKernel."));
auto& dev_ctx = ctx;
const int64_t* ids_data = ids_t->data<int64_t>();
xpu::ctx_guard RAII_GUARD(ctx.x_context());
const int64_t* ids_data;
if (ids_t->dtype() == phi::DataType::INT64) {
ids_data = ids_t->data<int64_t>();
} else {
int64_t* ids_tt = RAII_GUARD.alloc_l3_or_gm<int64_t>(ids_t->numel());
int r = xpu::cast<int32_t, int64_t>(
ctx.x_context(), ids_t->data<int>(), ids_tt, ids_t->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
ids_data = reinterpret_cast<const int64_t*>(ids_tt);
}
const T* d_output_data = d_output_t->data<T>();
T* d_table_data = dev_ctx.template Alloc<T>(d_table_t);
int xm = d_table_t->dims()[0];
......
......@@ -42,7 +42,17 @@ void EmbeddingKernel(const Context &ctx,
auto *table = table_t->data<T>();
auto *output = dev_ctx.template Alloc<T>(output_t);
const int64_t *ids = ids_t->data<int64_t>();
xpu::ctx_guard RAII_GUARD(ctx.x_context());
const int64_t *ids;
if (ids_t->dtype() == phi::DataType::INT64) {
ids = ids_t->data<int64_t>();
} else {
int64_t *ids_tt = RAII_GUARD.alloc_l3_or_gm<int64_t>(ids_t->numel());
int r = xpu::cast<int32_t, int64_t>(
ctx.x_context(), ids_t->data<int>(), ids_tt, ids_t->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
ids = reinterpret_cast<const int64_t *>(ids_tt);
}
PADDLE_ENFORCE_EQ(
ids_numel <= std::numeric_limits<int32_t>::max(),
......
......@@ -197,6 +197,91 @@ class TestSumOpError(unittest.TestCase):
self.assertRaises(Exception, test_list_of_none_input)
class TestLoDTensorAndSelectedRowsOp(unittest.TestCase):
def setUp(self):
self.height = 10
self.row_numel = 12
self.rows = [0, 1, 2, 3, 4, 5, 6]
self.dtype = np.float32
self.init_kernel_type()
def check_with_place(self, place, inplace):
self.check_input_and_optput(place, inplace, True, True, True)
def init_kernel_type(self):
pass
def _get_array(self, rows, row_numel):
array = np.ones((len(rows), row_numel)).astype(self.dtype)
for i in range(len(rows)):
array[i] *= rows[i]
return array
def check_input_and_optput(
self,
place,
inplace,
w1_has_data=False,
w2_has_data=False,
w3_has_data=False,
):
paddle.disable_static()
w1 = self.create_lod_tensor(place)
w2 = self.create_selected_rows(place, w2_has_data)
x = [w1, w2]
out = paddle.add_n(x)
result = np.ones((1, self.height)).astype(np.int32).tolist()[0]
for ele in self.rows:
result[ele] += 1
out_t = np.array(out)
self.assertEqual(out_t.shape[0], self.height)
np.testing.assert_array_equal(
out_t,
self._get_array([i for i in range(self.height)], self.row_numel)
* np.tile(np.array(result).reshape(self.height, 1), self.row_numel),
)
paddle.enable_static()
def create_selected_rows(self, place, has_data):
# create and initialize W Variable
if has_data:
rows = self.rows
else:
rows = []
w_array = self._get_array(self.rows, self.row_numel)
var = core.eager.Tensor(
core.VarDesc.VarType.FP32,
w_array.shape,
"selected_rows",
core.VarDesc.VarType.SELECTED_ROWS,
True,
)
w_selected_rows = var.value().get_selected_rows()
w_selected_rows.set_height(self.height)
w_selected_rows.set_rows(rows)
w_tensor = w_selected_rows.get_tensor()
w_tensor.set(w_array, place)
return var
def create_lod_tensor(self, place):
w_array = self._get_array(
[i for i in range(self.height)], self.row_numel
)
return paddle.to_tensor(w_array)
def test_w_is_selected_rows(self):
places = [core.XPUPlace(0)]
for place in places:
self.check_with_place(place, True)
support_types = get_xpu_op_support_types('sum')
for stype in support_types:
create_test_class(globals(), XPUTestSumOp, stype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册