Unverified · Commit 8363406a, authored by chentianyu03, committed by GitHub

[pten]add T, remove default value of DataType in DeviceContext::Alloc (#39620)

* add T to Alloc and remove default value of DataType in DeviceContext::Alloc

* add dtype
Parent 03b875a8
@@ -99,9 +99,7 @@ class DeviceContext {
   /**
    * @brief Allocate device memory for tensor.
    */
-  void* Alloc(TensorBase*,
-              DataType dtype = DataType::UNDEFINED,
-              size_t requested_size = 0) const;
+  void* Alloc(TensorBase*, DataType dtype, size_t requested_size = 0) const;

   template <typename T>
   T* Alloc(TensorBase* tensor, size_t requested_size = 0) const;
@@ -110,7 +108,7 @@ class DeviceContext {
    * @brief Allocate host memory for tensor.
    */
   void* HostAlloc(TensorBase* tensor,
-                  DataType dtype = DataType::UNDEFINED,
+                  DataType dtype,
                   size_t requested_size = 0) const;

   template <typename T>
...
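After this change every allocation must state its element type, either by passing a DataType explicitly or through the new typed overload that derives it from T. Below is a minimal compilable sketch of the two call forms; the DeviceContext/TensorBase here are toy stand-ins assumed for illustration (the real phi classes carry far more state), not the actual implementation.

#include <cstddef>
#include <cstdio>

// Toy stand-ins for the pten/phi types, assumed for illustration only.
enum class DataType { UNDEFINED, FLOAT32, INT64 };

struct TensorBase {
  DataType dtype_ = DataType::UNDEFINED;
  DataType dtype() const { return dtype_; }
};

template <typename T> DataType CppTypeToDataType();
template <> DataType CppTypeToDataType<float>() { return DataType::FLOAT32; }
template <> DataType CppTypeToDataType<long long>() { return DataType::INT64; }

class DeviceContext {
 public:
  // Untyped form: dtype is now a required argument (no UNDEFINED default).
  void* Alloc(TensorBase* t, DataType dtype, size_t requested_size = 0) const {
    (void)requested_size;  // a real context would allocate memory here
    t->dtype_ = dtype;
    return &t->dtype_;     // dummy non-null pointer for the demo
  }
  // Typed form added by this commit: dtype is derived from T.
  template <typename T>
  T* Alloc(TensorBase* t, size_t requested_size = 0) const {
    return static_cast<T*>(Alloc(t, CppTypeToDataType<T>(), requested_size));
  }
};

int main() {
  DeviceContext ctx;
  TensorBase a, b;
  ctx.Alloc(&a, DataType::FLOAT32);  // explicit dtype, mandatory now
  ctx.Alloc<float>(&b);              // dtype inferred from the template arg
  std::printf("%d %d\n", static_cast<int>(a.dtype()), static_cast<int>(b.dtype()));
}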
@@ -37,7 +37,7 @@ void Copy(const Context& dev_ctx,
           << src_place;
   dst->Resize(src.dims());
-  auto* dst_ptr = dev_ctx.Alloc(dst);
+  auto* dst_ptr = dev_ctx.Alloc(dst, src.dtype());
   if (src_ptr == dst_ptr) {
     VLOG(3) << "Skip copy the same data async from " << src_place << " to "
...
@@ -44,7 +44,7 @@ void SplitKernel(const Context& dev_ctx,
   std::vector<const DenseTensor*> shape_refer;
   for (size_t j = 0; j < outs.size(); ++j) {
-    dev_ctx.Alloc(outs[j]);
+    dev_ctx.template Alloc<T>(outs[j]);
     shape_refer.emplace_back(outs[j]);
   }
...
@@ -43,7 +43,7 @@ void SplitKernel(const Context& dev_ctx,
   std::vector<const DenseTensor*> shape_refer;
   for (size_t j = 0; j < outs.size(); ++j) {
-    dev_ctx.Alloc(outs[j]);
+    dev_ctx.template Alloc<T>(outs[j]);
     shape_refer.emplace_back(outs[j]);
   }
...
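The `.template` in `dev_ctx.template Alloc<T>(outs[j])` is required C++ syntax rather than a pten convention: SplitKernel is templated over both T and Context, so `dev_ctx` has the dependent type Context and the compiler cannot know that Alloc names a member template without the disambiguator. A self-contained illustration, with hypothetical Ctx/Kernel names not taken from the diff:

#include <cstdio>

struct Ctx {
  template <typename T>
  T* Alloc() const {
    static T storage{};  // toy backing store, just for the demo
    return &storage;
  }
};

template <typename T, typename Context>
void Kernel(const Context& dev_ctx) {
  // `Context` is a dependent type, so `template` is mandatory here;
  // without it, `dev_ctx.Alloc < T > (...)` parses as comparisons.
  T* p = dev_ctx.template Alloc<T>();
  std::printf("%p\n", static_cast<void*>(p));
}

int main() { Kernel<float, Ctx>(Ctx{}); }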
@@ -29,10 +29,10 @@ void ReshapeKernel(const Context& dev_ctx,
   MetaTensor meta_out(out);
   InferMetaFromVecValue(x, shape.GetData(), &meta_out);
   if (x.initialized() && x.Holder() == out->Holder()) {
-    dev_ctx.Alloc(out);
+    dev_ctx.Alloc(out, x.dtype());
     return;
   }
-  dev_ctx.Alloc(out);
+  dev_ctx.Alloc(out, x.dtype());
   // TODO(chenweihang): the output dims are overwrite after copying,
   // here we need to use copy method that only copy data
   auto dims = out->dims();
...
@@ -30,7 +30,7 @@ void Copy(const Context& dev_ctx,
           bool blocking,
           DenseTensor* dst) {
   auto* src_ptr = src.data();
-  auto* dst_ptr = dev_ctx.Alloc(dst);
+  auto* dst_ptr = dev_ctx.Alloc(dst, src.dtype());
   const auto& src_place = src.place();
   const auto& dst_place = dst->place();
...
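The Copy and ReshapeKernel hunks all land on the same migration pattern: when the output's element type comes from an input tensor rather than from a kernel template parameter, the input's dtype is forwarded explicitly. A small compilable sketch of that pattern, again using assumed toy types rather than the real phi classes:

#include <cstddef>

// Assumed minimal stand-ins, for illustration only.
enum class DataType { UNDEFINED, FLOAT32 };
struct TensorBase {
  DataType dtype_ = DataType::UNDEFINED;
  DataType dtype() const { return dtype_; }
};
struct DeviceContext {
  void* Alloc(TensorBase* t, DataType dtype, size_t requested_size = 0) const {
    (void)requested_size;
    t->dtype_ = dtype;  // the allocation now records an explicit dtype
    return nullptr;     // a real context would return device memory
  }
};

// The pattern from the Copy/Reshape hunks: forward the source dtype.
void ToyCopy(const DeviceContext& dev_ctx, const TensorBase& src,
             TensorBase* dst) {
  dev_ctx.Alloc(dst, src.dtype());  // was dev_ctx.Alloc(dst) before #39620
}

int main() {
  DeviceContext ctx;
  TensorBase src, dst;
  src.dtype_ = DataType::FLOAT32;
  ToyCopy(ctx, src, &dst);
  return dst.dtype() == DataType::FLOAT32 ? 0 : 1;
}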