Unverified · Commit 9d4d0c3b authored by zyfncg, committed by GitHub

【Pten】Adjust the Empty dev_api (#39143)

* adjust the Empty dev_api

* fix merge conflict

* fix sparse_utils_kernel
Parent 87f4a681
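In short, this commit removes the dtype template parameter from the Empty overload that already receives a DenseTensorMeta, since dtype, dims, and layout are all carried by the meta; only the meta-less overload still needs T to pick a dtype. A minimal sketch of the adjusted signatures and a typical call site, reconstructed from the hunks below rather than copied verbatim from the headers:

namespace pten {

// After this change the meta-taking overload needs no dtype parameter:
// dtype, dims, and layout all come from `meta`.
template <typename Context>
DenseTensor Empty(const Context& dev_ctx, DenseTensorMeta&& meta);

// The meta-less overload keeps T to derive the dtype and forwards to the
// overload above with a default {-1} shape and NCHW layout.
template <typename T, typename Context>
DenseTensor Empty(const Context& dev_ctx);

}  // namespace pten

// A typical call site after the change (dev_ctx and x stand for a device
// context and an existing DenseTensor, used here only for illustration):
//   auto meta = pten::DenseTensorMeta(x.dtype(), x.dims(), x.layout());
//   auto out  = pten::Empty(dev_ctx, std::move(meta));  // was Empty<T, Context>(dev_ctx, ...)
//   auto flt  = pten::Empty<float>(dev_ctx);            // dtype-only form still takes T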
@@ -32,7 +32,7 @@ void EmptyLikeKernel(const Context& dev_ctx, DenseTensor* out);
 // TODO(chenweihang): the tensor creation method need to be replaced later,
 // all kernel api call Empty here instead of making tensor self
-template <typename T, typename Context>
+template <typename Context>
 DenseTensor Empty(const Context& dev_ctx, DenseTensorMeta&& meta) {
   pten::DenseTensor dense_out(
       pten::make_intrusive<paddle::experimental::SharedStorage>(
@@ -43,10 +43,10 @@ DenseTensor Empty(const Context& dev_ctx, DenseTensorMeta&& meta) {
 template <typename T, typename Context>
 DenseTensor Empty(const Context& dev_ctx) {
-  return Empty<T, Context>(dev_ctx,
+  return Empty(dev_ctx,
               {paddle::experimental::CppTypeToDataType<T>::Type(),
                {-1},
                DataLayout::NCHW});
 }

 template <typename T, typename Context>
......
@@ -392,7 +392,7 @@ static inline void GetDoubleGradSafeTensor(const DeviceContext &dev_ctx,
     *ddx_safe = *ddx;
   } else {
     auto meta = pten::DenseTensorMeta(x.dtype(), x.dims(), x.layout());
-    *ddx_safe = pten::Empty<T, DeviceContext>(dev_ctx, std::move(meta));
+    *ddx_safe = pten::Empty(dev_ctx, std::move(meta));
     ddx_safe->mutable_data(dev_ctx.GetPlace());
     paddle::operators::math::SetConstant<DeviceContext, T> set_zero;
     set_zero(dev_ctx, ddx_safe, static_cast<T>(0));
......
@@ -76,10 +76,8 @@ void DenseToSparseCooKernel(const Context& dev_ctx,
                               {sparse_dim, static_cast<int64_t>(non_zero_num)},
                               DataLayout::NCHW);
   DenseTensorMeta values_meta(x.meta().dtype, values_dims, x.meta().layout);
-  pten::DenseTensor indices =
-      pten::Empty<int64_t, Context>(dev_ctx, std::move(indices_meta));
-  pten::DenseTensor values =
-      pten::Empty<T, Context>(dev_ctx, std::move(values_meta));
+  pten::DenseTensor indices = pten::Empty(dev_ctx, std::move(indices_meta));
+  pten::DenseTensor values = pten::Empty(dev_ctx, std::move(values_meta));
   int64_t* indices_data = indices.mutable_data<int64_t>(place);
   T* values_data = values.mutable_data<T>(place);
@@ -123,10 +121,8 @@ void SparseCsrToCooKernel(const Context& dev_ctx,
   DenseTensorMeta indices_meta(
       DataType::INT64, {sparse_dim, non_zero_num}, DataLayout::NCHW);
   DenseTensorMeta values_meta(x.dtype(), {non_zero_num}, x.layout());
-  pten::DenseTensor indices =
-      pten::Empty<int64_t, Context>(dev_ctx, std::move(indices_meta));
-  pten::DenseTensor values =
-      pten::Empty<T, Context>(dev_ctx, std::move(values_meta));
+  pten::DenseTensor indices = pten::Empty(dev_ctx, std::move(indices_meta));
+  pten::DenseTensor values = pten::Empty(dev_ctx, std::move(values_meta));
   int64_t* coo_indices = indices.mutable_data<int64_t>(place);
   int64_t* batch_ptr = x_dims.size() == 2 ? nullptr : coo_indices;
   int64_t* coo_rows_data =
......
@@ -111,14 +111,12 @@ void DenseToSparseCooKernel(const Context& dev_ctx,
   const int cols = dims_2d[1];
   auto nums_meta =
       pten::DenseTensorMeta(DataType::INT32, {1}, pten::DataLayout::NCHW);
-  DenseTensor nums =
-      pten::Empty<int64_t, Context>(dev_ctx, std::move(nums_meta));
+  DenseTensor nums = pten::Empty(dev_ctx, std::move(nums_meta));
   auto x_dims_meta =
       pten::DenseTensorMeta(DataType::INT64,
                             {static_cast<int64_t>(x_dims.size())},
                             pten::DataLayout::NCHW);
-  DenseTensor d_x_dims =
-      pten::Empty<T, Context>(dev_ctx, std::move(x_dims_meta));
+  DenseTensor d_x_dims = pten::Empty(dev_ctx, std::move(x_dims_meta));
   const auto place = dev_ctx.GetPlace();
@@ -136,8 +134,7 @@ void DenseToSparseCooKernel(const Context& dev_ctx,
   auto temp_indexs_meta =
       pten::DenseTensorMeta(DataType::INT32, {rows}, pten::DataLayout::NCHW);
-  DenseTensor temp_indexs =
-      pten::Empty<T, Context>(dev_ctx, std::move(temp_indexs_meta));
+  DenseTensor temp_indexs = pten::Empty(dev_ctx, std::move(temp_indexs_meta));
   int* temp_indexs_ptr = temp_indexs.mutable_data<int>(place);
   GetNonZeroNums<<<grid_size, block_size, 0, dev_ctx.stream()>>>(
       x_data, rows, cols, nums_ptr, temp_indexs_ptr);
@@ -269,11 +266,9 @@ void SparseCsrToCooKernel(const Context& dev_ctx,
       DataType::INT64, {sparse_dim, non_zero_num}, DataLayout::NCHW);
   DenseTensorMeta values_meta(x.dtype(), {non_zero_num}, x.layout());
   DenseTensorMeta offsets_meta(DataType::INT32, {batchs}, DataLayout::NCHW);
-  DenseTensor indices =
-      pten::Empty<int64_t, Context>(dev_ctx, std::move(indices_meta));
-  DenseTensor values = pten::Empty<T, Context>(dev_ctx, std::move(values_meta));
-  DenseTensor offsets =
-      pten::Empty<T, Context>(dev_ctx, std::move(offsets_meta));
+  DenseTensor indices = pten::Empty(dev_ctx, std::move(indices_meta));
+  DenseTensor values = pten::Empty(dev_ctx, std::move(values_meta));
+  DenseTensor offsets = pten::Empty(dev_ctx, std::move(offsets_meta));
   int64_t* coo_indices = indices.mutable_data<int64_t>(place);
   int64_t* batch_ptr = x_dims.size() == 2 ? nullptr : coo_indices;
   int64_t* coo_rows_data =
......