提交 012d45da 编写于 作者: Z zyfncg 提交者: GitHub

Revert "【Pten】Adjust the Empyt dev_api (#39143)"

This reverts commit 9d4d0c3b.
上级 9d4d0c3b
......@@ -32,7 +32,7 @@ void EmptyLikeKernel(const Context& dev_ctx, DenseTensor* out);
// TODO(chenweihang): the tensor creation method need to be replaced later,
// all kernel api call Empty here instead of making tensor self
template <typename Context>
template <typename T, typename Context>
DenseTensor Empty(const Context& dev_ctx, DenseTensorMeta&& meta) {
pten::DenseTensor dense_out(
pten::make_intrusive<paddle::experimental::SharedStorage>(
......@@ -43,7 +43,7 @@ DenseTensor Empty(const Context& dev_ctx, DenseTensorMeta&& meta) {
template <typename T, typename Context>
DenseTensor Empty(const Context& dev_ctx) {
return Empty(dev_ctx,
return Empty<T, Context>(dev_ctx,
{paddle::experimental::CppTypeToDataType<T>::Type(),
{-1},
DataLayout::NCHW});
......
......@@ -392,7 +392,7 @@ static inline void GetDoubleGradSafeTensor(const DeviceContext &dev_ctx,
*ddx_safe = *ddx;
} else {
auto meta = pten::DenseTensorMeta(x.dtype(), x.dims(), x.layout());
*ddx_safe = pten::Empty(dev_ctx, std::move(meta));
*ddx_safe = pten::Empty<T, DeviceContext>(dev_ctx, std::move(meta));
ddx_safe->mutable_data(dev_ctx.GetPlace());
paddle::operators::math::SetConstant<DeviceContext, T> set_zero;
set_zero(dev_ctx, ddx_safe, static_cast<T>(0));
......
......@@ -76,8 +76,10 @@ void DenseToSparseCooKernel(const Context& dev_ctx,
{sparse_dim, static_cast<int64_t>(non_zero_num)},
DataLayout::NCHW);
DenseTensorMeta values_meta(x.meta().dtype, values_dims, x.meta().layout);
pten::DenseTensor indices = pten::Empty(dev_ctx, std::move(indices_meta));
pten::DenseTensor values = pten::Empty(dev_ctx, std::move(values_meta));
pten::DenseTensor indices =
pten::Empty<int64_t, Context>(dev_ctx, std::move(indices_meta));
pten::DenseTensor values =
pten::Empty<T, Context>(dev_ctx, std::move(values_meta));
int64_t* indices_data = indices.mutable_data<int64_t>(place);
T* values_data = values.mutable_data<T>(place);
......@@ -121,8 +123,10 @@ void SparseCsrToCooKernel(const Context& dev_ctx,
DenseTensorMeta indices_meta(
DataType::INT64, {sparse_dim, non_zero_num}, DataLayout::NCHW);
DenseTensorMeta values_meta(x.dtype(), {non_zero_num}, x.layout());
pten::DenseTensor indices = pten::Empty(dev_ctx, std::move(indices_meta));
pten::DenseTensor values = pten::Empty(dev_ctx, std::move(values_meta));
pten::DenseTensor indices =
pten::Empty<int64_t, Context>(dev_ctx, std::move(indices_meta));
pten::DenseTensor values =
pten::Empty<T, Context>(dev_ctx, std::move(values_meta));
int64_t* coo_indices = indices.mutable_data<int64_t>(place);
int64_t* batch_ptr = x_dims.size() == 2 ? nullptr : coo_indices;
int64_t* coo_rows_data =
......
......@@ -111,12 +111,14 @@ void DenseToSparseCooKernel(const Context& dev_ctx,
const int cols = dims_2d[1];
auto nums_meta =
pten::DenseTensorMeta(DataType::INT32, {1}, pten::DataLayout::NCHW);
DenseTensor nums = pten::Empty(dev_ctx, std::move(nums_meta));
DenseTensor nums =
pten::Empty<int64_t, Context>(dev_ctx, std::move(nums_meta));
auto x_dims_meta =
pten::DenseTensorMeta(DataType::INT64,
{static_cast<int64_t>(x_dims.size())},
pten::DataLayout::NCHW);
DenseTensor d_x_dims = pten::Empty(dev_ctx, std::move(x_dims_meta));
DenseTensor d_x_dims =
pten::Empty<T, Context>(dev_ctx, std::move(x_dims_meta));
const auto place = dev_ctx.GetPlace();
......@@ -134,7 +136,8 @@ void DenseToSparseCooKernel(const Context& dev_ctx,
auto temp_indexs_meta =
pten::DenseTensorMeta(DataType::INT32, {rows}, pten::DataLayout::NCHW);
DenseTensor temp_indexs = pten::Empty(dev_ctx, std::move(temp_indexs_meta));
DenseTensor temp_indexs =
pten::Empty<T, Context>(dev_ctx, std::move(temp_indexs_meta));
int* temp_indexs_ptr = temp_indexs.mutable_data<int>(place);
GetNonZeroNums<<<grid_size, block_size, 0, dev_ctx.stream()>>>(
x_data, rows, cols, nums_ptr, temp_indexs_ptr);
......@@ -266,9 +269,11 @@ void SparseCsrToCooKernel(const Context& dev_ctx,
DataType::INT64, {sparse_dim, non_zero_num}, DataLayout::NCHW);
DenseTensorMeta values_meta(x.dtype(), {non_zero_num}, x.layout());
DenseTensorMeta offsets_meta(DataType::INT32, {batchs}, DataLayout::NCHW);
DenseTensor indices = pten::Empty(dev_ctx, std::move(indices_meta));
DenseTensor values = pten::Empty(dev_ctx, std::move(values_meta));
DenseTensor offsets = pten::Empty(dev_ctx, std::move(offsets_meta));
DenseTensor indices =
pten::Empty<int64_t, Context>(dev_ctx, std::move(indices_meta));
DenseTensor values = pten::Empty<T, Context>(dev_ctx, std::move(values_meta));
DenseTensor offsets =
pten::Empty<T, Context>(dev_ctx, std::move(offsets_meta));
int64_t* coo_indices = indices.mutable_data<int64_t>(place);
int64_t* batch_ptr = x_dims.size() == 2 ? nullptr : coo_indices;
int64_t* coo_rows_data =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册