未验证 提交 8363406a 编写于 作者: C chentianyu03 提交者: GitHub

[pten]add T, remove default value of DataType in DeviceContext::Alloc (#39620)

* add T to Alloc and remove default value of DataType in DeviceContext::Alloc

* add dtype
上级 03b875a8
......@@ -99,9 +99,7 @@ class DeviceContext {
/**
* @brief Allocate device memory for tensor.
*/
void* Alloc(TensorBase*,
DataType dtype = DataType::UNDEFINED,
size_t requested_size = 0) const;
void* Alloc(TensorBase*, DataType dtype, size_t requested_size = 0) const;
template <typename T>
T* Alloc(TensorBase* tensor, size_t requested_size = 0) const;
......@@ -110,7 +108,7 @@ class DeviceContext {
* @brief Allocate host memory for tensor.
*/
void* HostAlloc(TensorBase* tensor,
DataType dtype = DataType::UNDEFINED,
DataType dtype,
size_t requested_size = 0) const;
template <typename T>
......
......@@ -37,7 +37,7 @@ void Copy(const Context& dev_ctx,
<< src_place;
dst->Resize(src.dims());
auto* dst_ptr = dev_ctx.Alloc(dst);
auto* dst_ptr = dev_ctx.Alloc(dst, src.dtype());
if (src_ptr == dst_ptr) {
VLOG(3) << "Skip copy the same data async from " << src_place << " to "
......
......@@ -44,7 +44,7 @@ void SplitKernel(const Context& dev_ctx,
std::vector<const DenseTensor*> shape_refer;
for (size_t j = 0; j < outs.size(); ++j) {
dev_ctx.Alloc(outs[j]);
dev_ctx.template Alloc<T>(outs[j]);
shape_refer.emplace_back(outs[j]);
}
......
......@@ -43,7 +43,7 @@ void SplitKernel(const Context& dev_ctx,
std::vector<const DenseTensor*> shape_refer;
for (size_t j = 0; j < outs.size(); ++j) {
dev_ctx.Alloc(outs[j]);
dev_ctx.template Alloc<T>(outs[j]);
shape_refer.emplace_back(outs[j]);
}
......
......@@ -29,10 +29,10 @@ void ReshapeKernel(const Context& dev_ctx,
MetaTensor meta_out(out);
InferMetaFromVecValue(x, shape.GetData(), &meta_out);
if (x.initialized() && x.Holder() == out->Holder()) {
dev_ctx.Alloc(out);
dev_ctx.Alloc(out, x.dtype());
return;
}
dev_ctx.Alloc(out);
dev_ctx.Alloc(out, x.dtype());
// TODO(chenweihang): the output dims are overwrite after copying,
// here we need to use copy method that only copy data
auto dims = out->dims();
......
......@@ -30,7 +30,7 @@ void Copy(const Context& dev_ctx,
bool blocking,
DenseTensor* dst) {
auto* src_ptr = src.data();
auto* dst_ptr = dev_ctx.Alloc(dst);
auto* dst_ptr = dev_ctx.Alloc(dst, src.dtype());
const auto& src_place = src.place();
const auto& dst_place = dst->place();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册