未验证 提交 5631da9c 编写于 作者: A Aurelius84 提交者: GitHub

[PTen]Support AllocateFrom in Tensor and Alloc/HostAlloc in Context (#39022)

* Support allocate_from in Tensor and allocate_data in Context

* fix #ifdef CUDA

* fix cycle depends

* fix test_xxx_dev_api failed

* fix windows compiling error

* fix unittest

* modify into PImpl

* fix selected rows

* add TODO comment

* refine interface according reviewer
上级 f3f16126
......@@ -840,6 +840,28 @@ void* AllocatorFacade::GetBasePtr(
return m_->GetBasePtr(allocation);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
const std::shared_ptr<Allocator>& AllocatorFacade::GetAllocator(
const platform::Place& place, const gpuStream_t& stream) {
if (FLAGS_use_stream_safe_cuda_allocator && platform::is_gpu_place(place) &&
FLAGS_use_system_allocator == false) {
#ifdef PADDLE_WITH_CUDA
if (UNLIKELY(platform::CUDAGraph::IsCapturing())) {
return m_->GetAllocator(place,
/* A non-zero num to choose allocator_ */ 1);
}
#endif
return m_->GetAllocator(place, stream, /*create_if_not_found=*/true);
}
return m_->GetAllocator(place, /* A non-zero num to choose allocator_ */ 1);
}
#endif
const std::shared_ptr<Allocator>& AllocatorFacade::GetZeroAllocator(
const platform::Place& place) {
return m_->GetAllocator(place, /* zero size */ 0);
}
std::shared_ptr<pten::Allocation> AllocatorFacade::AllocShared(
const platform::Place& place, size_t size) {
return std::shared_ptr<pten::Allocation>(Alloc(place, size));
......
......@@ -53,6 +53,14 @@ class AllocatorFacade {
void* GetBasePtr(const std::shared_ptr<Allocation>& allocation);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
const std::shared_ptr<Allocator>& GetAllocator(const platform::Place& place,
const gpuStream_t& stream);
#endif
const std::shared_ptr<Allocator>& GetZeroAllocator(
const platform::Place& place);
// Allocate a shared allocation.
std::shared_ptr<Allocation> AllocShared(const platform::Place& place,
size_t size);
......
......@@ -23,6 +23,7 @@ limitations under the License. */
#endif
#include "glog/logging.h"
#include "paddle/fluid/framework/expect.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
......@@ -136,11 +137,39 @@ inline void EmplaceDeviceContext(
map_ptr,
platform::Place p) {
using PtrType = std::unique_ptr<DeviceContext>;
map_ptr->emplace(p, std::async(std::launch::deferred, [=] {
// lazy evaluation. i.e., only create device context at
// first `Get`
return PtrType(new DevCtx(p));
}));
map_ptr->emplace(
p, std::async(std::launch::deferred, [=] {
// lazy evaluation. i.e., only create device context at
// first `Get`
auto* dev_ctx = new DevCtx(p);
if (is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto* cuda_ctx = dynamic_cast<CUDADeviceContext*>(dev_ctx);
PADDLE_ENFORCE_NOT_NULL(
cuda_ctx,
platform::errors::InvalidArgument(
"Failed to dynamic_cast dev_ctx into CUDADeviceContext."));
dev_ctx->SetDeviceAllocator(
memory::allocation::AllocatorFacade::Instance()
.GetAllocator(p, cuda_ctx->context()->RawStream())
.get());
#endif
} else {
dev_ctx->SetDeviceAllocator(
memory::allocation::AllocatorFacade::Instance()
.GetAllocator(p)
.get());
}
dev_ctx->SetHostAllocator(
memory::allocation::AllocatorFacade::Instance()
.GetAllocator(platform::CPUPlace())
.get());
dev_ctx->SetZeroAllocator(
memory::allocation::AllocatorFacade::Instance()
.GetZeroAllocator(p)
.get());
return PtrType(dev_ctx);
}));
}
DeviceContextPool::DeviceContextPool(
......
......@@ -68,6 +68,45 @@ bool DenseTensor::IsSharedWith(const DenseTensor& b) const {
return holder_ && holder_ == b.Holder();
}
void* DenseTensor::AllocateFrom(Allocator* allocator,
DataType dtype,
size_t requested_size) {
PADDLE_ENFORCE_NOT_NULL(
allocator,
paddle::platform::errors::InvalidArgument(
"Required allocator shall not be nullptr, but received nullptr."));
if (this->dtype() != dtype) {
VLOG(10) << "change data type in mutbale_data, target dtype - " << dtype;
meta_.dtype = dtype;
}
PADDLE_ENFORCE(
valid(),
paddle::platform::errors::PreconditionNotMet(
"The meta data must be valid when call the mutable data function."));
size_t bytes = numel() * SizeOf(this->dtype());
if (requested_size) {
PADDLE_ENFORCE_GE(requested_size,
bytes,
paddle::platform::errors::InvalidArgument(
"The reserved size %d should be enough to meet the "
"volume required by metadata %d.",
requested_size,
bytes));
bytes = requested_size;
}
// TODO(paddle-dev): In case of the allocator of storage_ is different with
// the incoming allocator, we should re-alloc data using the incoming
// allocator.
if (!holder_ || holder_->size() < bytes + meta_.offset) {
meta_.offset = 0;
VLOG(10) << "Allocate data with bytes: " << bytes;
ResetHolder(allocator->Allocate(bytes));
}
return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
meta_.offset);
}
template <typename T>
const T* DenseTensor::data() const {
check_memory_size();
......
......@@ -124,6 +124,12 @@ class DenseTensor : public TensorBase,
/// return Whether the storage is allocated.
bool initialized() const override { return holder_ && holder_->ptr(); }
/// \brief Allocate memory with requested size from allocator.
/// \return The mutable data pointer value of type T.
void* AllocateFrom(Allocator* allocator,
DataType dtype,
size_t requested_size = 0) override;
/// \brief Check if storage is shared with other objects.
/// \return Whether the storage is shared with other objects.
bool IsSharedWith(const DenseTensor& b) const;
......
......@@ -13,45 +13,119 @@
// limitations under the License.
#include "paddle/pten/core/device_context.h"
#include "paddle/pten/api/ext/exception.h"
#include "paddle/pten/core/enforce.h"
#include "paddle/pten/core/tensor_base.h"
namespace pten {
using DataType = paddle::experimental::DataType;
struct DeviceContext::Impl {
Impl() = default;
~Impl() = default;
void SetDeviceAllocator(Allocator* allocator) {
void SetDeviceAllocator(const Allocator* allocator) {
PADDLE_ENFORCE_NOT_NULL(
allocator,
pten::errors::InvalidArgument(
"Required allocator shall not be nullptr, but received nullptr."));
device_allocator_ = allocator;
}
void SetHostAllocator(Allocator* allocator) { host_allocator_ = allocator; }
void SetHostAllocator(const Allocator* allocator) {
PADDLE_ENFORCE_NOT_NULL(
allocator,
pten::errors::InvalidArgument(
"Required allocator shall not be nullptr, but received nullptr."));
host_allocator_ = allocator;
}
void SetZeroAllocator(const Allocator* allocator) {
PADDLE_ENFORCE_NOT_NULL(
allocator,
pten::errors::InvalidArgument(
"Required allocator shall not be nullptr, but received nullptr."));
zero_allocator_ = allocator;
}
const Allocator& GetDeviceAllocator() const {
PD_CHECK(device_allocator_ != nullptr, "the device_allocator is nullptr.");
PADDLE_ENFORCE_NOT_NULL(
device_allocator_,
pten::errors::InvalidArgument("Required device_allocator_ shall not be "
"nullptr, but received nullptr."));
return *device_allocator_;
}
const Allocator& GetHostAllocator() const {
PD_CHECK(host_allocator_ != nullptr, "the host_allocator is nullptr.");
PADDLE_ENFORCE_NOT_NULL(
host_allocator_,
pten::errors::InvalidArgument("Required host_allocator_ shall not be "
"nullptr, but received nullptr."));
return *host_allocator_;
}
// TODO(Wilber): Add impl. It seems that tensorbase not have interface to
// communicate with allocator.
void HostAlloc(TensorBase* tensor) {}
void DeviceAlloc(TensorBase* tensor) {}
const Allocator& GetZeroAllocator() const {
PADDLE_ENFORCE_NOT_NULL(
zero_allocator_,
pten::errors::InvalidArgument("Required host_allocator_ shall not be "
"nullptr, but received nullptr."));
return *zero_allocator_;
}
void* Alloc(TensorBase* tensor,
DataType dtype = DataType::UNDEFINED,
size_t requested_size = 0) const {
PADDLE_ENFORCE_NOT_NULL(
tensor,
pten::errors::InvalidArgument(
"Required tensor shall not be nullptr, but received nullptr."));
if (dtype == DataType::UNDEFINED) {
dtype = tensor->dtype();
}
auto* allocator =
tensor->numel() == 0 ? zero_allocator_ : device_allocator_;
return tensor->AllocateFrom(
const_cast<Allocator*>(allocator), dtype, requested_size);
}
template <typename T>
T* Alloc(TensorBase* tensor, size_t requested_size = 0) const {
DataType dtype = paddle::experimental::CppTypeToDataType<T>::Type();
return static_cast<T*>(Alloc(tensor, dtype, requested_size));
}
Allocator* device_allocator_{nullptr};
Allocator* host_allocator_{nullptr};
void* HostAlloc(TensorBase* tensor,
DataType dtype = DataType::UNDEFINED,
size_t requested_size = 0) const {
PADDLE_ENFORCE_NOT_NULL(
tensor,
pten::errors::InvalidArgument(
"Required tensor shall not be nullptr, but received nullptr."));
if (dtype == DataType::UNDEFINED) {
dtype = tensor->dtype();
}
auto* allocator = tensor->numel() == 0 ? zero_allocator_ : host_allocator_;
return tensor->AllocateFrom(
const_cast<Allocator*>(allocator), dtype, requested_size);
}
template <typename T>
T* HostAlloc(pten::TensorBase* tensor, size_t requested_size = 0) const {
DataType dtype = paddle::experimental::CppTypeToDataType<T>::Type();
return static_cast<T*>(HostAlloc(tensor, dtype, requested_size));
}
private:
const Allocator* device_allocator_{nullptr};
const Allocator* host_allocator_{nullptr};
const Allocator* zero_allocator_{nullptr};
};
DeviceContext::DeviceContext() { impl_ = std::make_unique<Impl>(); }
DeviceContext::DeviceContext(const DeviceContext& other) {
impl_->SetDeviceAllocator(
const_cast<Allocator*>(&other.GetDeviceAllocator()));
impl_->SetHostAllocator(const_cast<Allocator*>(&other.GetHostAllocator()));
impl_->SetHostAllocator(&other.GetHostAllocator());
impl_->SetDeviceAllocator(&other.GetDeviceAllocator());
impl_->SetZeroAllocator(&other.GetZeroAllocator());
}
DeviceContext::DeviceContext(DeviceContext&& other) {
......@@ -60,26 +134,71 @@ DeviceContext::DeviceContext(DeviceContext&& other) {
DeviceContext::~DeviceContext() = default;
void DeviceContext::SetHostAllocator(Allocator* allocator) {
impl_->SetHostAllocator(allocator);
void DeviceContext::SetDeviceAllocator(const Allocator* allocator) {
impl_->SetDeviceAllocator(allocator);
}
void DeviceContext::SetDeviceAllocator(Allocator* allocator) {
impl_->SetDeviceAllocator(allocator);
const Allocator& DeviceContext::GetDeviceAllocator() const {
return impl_->GetDeviceAllocator();
}
void DeviceContext::SetHostAllocator(const Allocator* allocator) {
impl_->SetHostAllocator(allocator);
}
const Allocator& DeviceContext::GetHostAllocator() const {
return impl_->GetHostAllocator();
}
const Allocator& DeviceContext::GetDeviceAllocator() const {
return impl_->GetDeviceAllocator();
void DeviceContext::SetZeroAllocator(const Allocator* allocator) {
impl_->SetZeroAllocator(allocator);
}
void DeviceContext::HostAlloc(TensorBase* tensor) { impl_->HostAlloc(tensor); }
const Allocator& DeviceContext::GetZeroAllocator() const {
return impl_->GetZeroAllocator();
}
void DeviceContext::DeviceAlloc(TensorBase* tensor) {
impl_->DeviceAlloc(tensor);
void* DeviceContext::Alloc(TensorBase* tensor,
DataType dtype,
size_t requested_size) const {
return impl_->Alloc(tensor, dtype, requested_size);
}
template <typename T>
T* DeviceContext::Alloc(TensorBase* tensor, size_t requested_size) const {
return impl_->Alloc<T>(tensor, requested_size);
}
void* DeviceContext::HostAlloc(TensorBase* tensor,
DataType dtype,
size_t requested_size) const {
return impl_->HostAlloc(tensor, dtype, requested_size);
}
template <typename T>
T* DeviceContext::HostAlloc(TensorBase* tensor, size_t requested_size) const {
return impl_->HostAlloc<T>(tensor, requested_size);
}
#define DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(dtype) \
template dtype* DeviceContext::Alloc(TensorBase* tensor, \
size_t requested_size) const; \
template dtype* DeviceContext::HostAlloc(TensorBase* tensor, \
size_t requested_size) const;
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(bool)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int8_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(uint8_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int16_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int32_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(int64_t)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(float)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(double)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::bfloat16)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::float16)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex64)
DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex128)
#undef DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION
} // namespace pten
......@@ -19,6 +19,7 @@ limitations under the License. */
// TODO(wilber): Do we need to use place in pten kernel?
#include "paddle/pten/common/place.h"
#include "paddle/pten/common/data_type.h"
#include "paddle/pten/core/allocator.h"
namespace pten {
......@@ -31,6 +32,8 @@ class TensorBase;
* DeviceContext.
*/
class DeviceContext {
using DataType = paddle::experimental::DataType;
public:
/**
* @brief Default construct.
......@@ -53,42 +56,61 @@ class DeviceContext {
virtual ~DeviceContext();
/**
* @brief Set the deveice-releated Allocator object.
* @brief Set the device-related Allocator object.
*
* @param allocator
*/
void SetDeviceAllocator(Allocator*);
void SetDeviceAllocator(const Allocator*);
/**
* @brief Get the const deveice-releated Allocator object.
* @brief Set the host Allocator object.
*
* @return Allocator
* @param allocator
*/
const Allocator& GetDeviceAllocator() const;
void SetHostAllocator(const Allocator*);
/**
* @brief Allocate device memory for tensor.
*/
void DeviceAlloc(pten::TensorBase*);
* @brief Set the zero-size Allocator object.
*
* @param allocator
*/
void SetZeroAllocator(const Allocator*);
/**
* @brief Set the host Allocator object.
* @brief Get the const Allocator object.
*
* @param allocator
* @return Allocator
*/
void SetHostAllocator(Allocator*);
const Allocator& GetDeviceAllocator() const;
/**
* @brief Get the const host Allocator object.
* @brief Get the const device-related Allocator object.
*
* @return Allocator
*/
const Allocator& GetHostAllocator() const;
const Allocator& GetZeroAllocator() const;
/**
* @brief Allocate device memory for tensor.
*/
void* Alloc(TensorBase*,
DataType dtype = DataType::UNDEFINED,
size_t requested_size = 0) const;
template <typename T>
T* Alloc(TensorBase* tensor, size_t requested_size = 0) const;
/**
* @brief Allocate host memory for tensor.
*/
void HostAlloc(pten::TensorBase*);
void* HostAlloc(TensorBase* tensor,
DataType dtype = DataType::UNDEFINED,
size_t requested_size = 0) const;
template <typename T>
T* HostAlloc(TensorBase* tensor, size_t requested_size = 0) const;
// TODO(wilber): Just for the convenience of migrating the code, it will be
// modified or removed later.
......
......@@ -91,6 +91,12 @@ struct TensorFillVisitor {
int64_t size_;
};
void* SelectedRows::AllocateFrom(Allocator* allocator,
DataType dtype,
size_t requested_size) {
return value_->AllocateFrom(allocator, dtype, requested_size);
}
bool SelectedRows::HasKey(int64_t key) const {
return std::find(rows_.begin(), rows_.end(), key) == rows_.end() ? false
: true;
......
......@@ -113,6 +113,10 @@ class SelectedRows : public TensorBase,
bool auto_grown = false,
bool is_test = false);
void* AllocateFrom(Allocator* allocator,
DataType dtype,
size_t requested_size = 0) override;
/*
* @brief Get the index of the key from id_to_index_ map. If the key not
* exist,
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/pten/common/backend.h"
#include "paddle/pten/common/data_type.h"
#include "paddle/pten/common/layout.h"
#include "paddle/pten/core/allocator.h"
#include "paddle/pten/core/ddim.h"
#include "paddle/pten/core/storage.h"
#include "paddle/pten/core/utils/type_registry.h"
......@@ -61,6 +62,16 @@ class TensorBase {
/// return Whether the storage is allocated.
virtual bool initialized() const = 0;
// TODO(Aurelius84): This interface is under intermediate state now.
// We will remove DataType argument in the future. Please DO NOT
// rely on Datatype to much when design and implement other feature.
/// \brief Allocate memory with requested size from allocator.
/// \return The mutable data pointer value of type T.
virtual void* AllocateFrom(Allocator* allocator,
DataType dtype,
size_t requested_size = 0) = 0;
/// \brief Return the type information of the derived class to support
/// safely downcast in non-rtti environment.
/// return The type information of the derived class.
......
......@@ -36,7 +36,7 @@ void CastKernelImpl(const CPUContext& dev_ctx,
auto numel = x.numel();
auto* in_end = in_begin + numel;
auto* out_begin = out->mutable_data<OutT>(dev_ctx.GetPlace());
auto* out_begin = dev_ctx.Alloc<OutT>(out);
paddle::platform::Transform<CPUContext> trans;
trans(dev_ctx,
......
......@@ -37,7 +37,7 @@ void Copy(const Context& dev_ctx,
<< src_place;
dst->Resize(src.dims());
auto* dst_ptr = dst->mutable_data(src_place);
auto* dst_ptr = dev_ctx.Alloc(dst);
if (src_ptr == dst_ptr) {
VLOG(3) << "Skip copy the same data async from " << src_place << " to "
......
......@@ -29,7 +29,7 @@ void DotKernel(const Context& dev_ctx,
DenseTensor* out) {
auto const *x_ptr = x.data<T>(), *x_ptr_ = &x_ptr[0];
auto const *y_ptr = y.data<T>(), *y_ptr_ = &y_ptr[0];
auto* z = out->mutable_data<T>(dev_ctx.GetPlace());
T* z = dev_ctx.template Alloc<T>(out);
// Loop over the total N elements of both operands while sum-reducing every
// B pairs along the way where B is the dimension of the least ordered axis
......
......@@ -45,10 +45,8 @@ struct SameDimsAddFunctor<
const DenseTensor& y,
DenseTensor* z) {
auto blas = paddle::operators::math::GetBlas<DevCtx, T>(dev_ctx);
blas.VADD(x.numel(),
x.data<T>(),
y.data<T>(),
z->mutable_data<T>(dev_ctx.GetPlace()));
blas.VADD(
x.numel(), x.data<T>(), y.data<T>(), dev_ctx.template Alloc<T>(z));
}
};
......@@ -61,7 +59,7 @@ struct SameDimsAddFunctor<
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* z) {
z->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(z);
auto eigen_x = pten::EigenVector<T>::Flatten(x);
auto eigen_y = pten::EigenVector<T>::Flatten(y);
auto eigen_z = pten::EigenVector<T>::Flatten(*z);
......@@ -89,10 +87,8 @@ struct SameDimsSubtractFunctor<
const DenseTensor& y,
DenseTensor* z) {
auto blas = paddle::operators::math::GetBlas<DevCtx, T>(dev_ctx);
blas.VSUB(x.numel(),
x.data<T>(),
y.data<T>(),
z->mutable_data<T>(dev_ctx.GetPlace()));
blas.VSUB(
x.numel(), x.data<T>(), y.data<T>(), dev_ctx.template Alloc<T>(z));
}
};
......@@ -147,10 +143,8 @@ struct SameDimsDivideFunctor<
const DenseTensor& y,
DenseTensor* z) {
auto blas = paddle::operators::math::GetBlas<DevCtx, T>(dev_ctx);
blas.VDIV(x.numel(),
x.data<T>(),
y.data<T>(),
z->mutable_data<T>(dev_ctx.GetPlace()));
blas.VDIV(
x.numel(), x.data<T>(), y.data<T>(), dev_ctx.template Alloc<T>(z));
}
};
......@@ -173,10 +167,8 @@ struct SameDimsMultiplyFunctor<
const DenseTensor& y,
DenseTensor* z) {
auto blas = paddle::operators::math::GetBlas<DevCtx, T>(dev_ctx);
blas.VMUL(x.numel(),
x.data<T>(),
y.data<T>(),
z->mutable_data<T>(dev_ctx.GetPlace()));
blas.VMUL(
x.numel(), x.data<T>(), y.data<T>(), dev_ctx.template Alloc<T>(z));
}
};
......@@ -241,8 +233,8 @@ void CommonGradBroadcastCPU(const DenseTensor& x,
const T* y_data = y.data<T>();
const Tout* out_data = out.data<Tout>();
const Tout* dout_data = dout.data<Tout>();
T* dx_data = dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace());
T* dy_data = dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace());
T* dx_data = dx == nullptr ? nullptr : ctx.Alloc<T>(dx);
T* dy_data = dy == nullptr ? nullptr : ctx.Alloc<T>(dy);
if (dx_data != nullptr) {
memset(dx_data, 0, dx->numel() * sizeof(T));
}
......@@ -292,7 +284,7 @@ void CommonForwardBroadcastCPU(const DenseTensor& x,
PADDLE_ENFORCE_NOT_NULL(y_data,
paddle::platform::errors::InvalidArgument(
"The input Y should not be empty."));
OutType* out_data = z->mutable_data<OutType>(ctx.GetPlace());
OutType* out_data = ctx.Alloc<OutType>(z);
const int out_size = std::accumulate(
out_dims_array, out_dims_array + max_dim, 1, std::multiplies<int>());
......@@ -373,7 +365,7 @@ void ElementwiseCompute(const CPUContext& dev_ctx,
int axis,
Functor func,
DenseTensor* z) {
z->mutable_data<OutType>(dev_ctx.GetPlace());
dev_ctx.Alloc<OutType>(z);
auto x_dims = x.dims();
auto y_dims = y.dims();
bool is_xsize_larger = true;
......@@ -677,32 +669,30 @@ void ElemwiseGradComputeWithBroadcast(const CPUContext& ctx,
return;
}
if (post == 1) {
ElemwiseGradBroadcast1CPU(
x.data<T>(),
y.data<T>(),
out.data<Tout>(),
dout.data<Tout>(),
pre,
n,
is_xsize_larger,
dx_op,
dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
ElemwiseGradBroadcast1CPU(x.data<T>(),
y.data<T>(),
out.data<Tout>(),
dout.data<Tout>(),
pre,
n,
is_xsize_larger,
dx_op,
dy_op,
dx == nullptr ? nullptr : ctx.Alloc<T>(dx),
dy == nullptr ? nullptr : ctx.Alloc<T>(dy));
} else {
ElemwiseGradBroadcast2CPU(
x.data<T>(),
y.data<T>(),
out.data<Tout>(),
dout.data<Tout>(),
pre,
n,
post,
is_xsize_larger,
dx_op,
dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
ElemwiseGradBroadcast2CPU(x.data<T>(),
y.data<T>(),
out.data<Tout>(),
dout.data<Tout>(),
pre,
n,
post,
is_xsize_larger,
dx_op,
dy_op,
dx == nullptr ? nullptr : ctx.Alloc<T>(dx),
dy == nullptr ? nullptr : ctx.Alloc<T>(dy));
}
}
......
......@@ -37,7 +37,7 @@ namespace pten {
const DenseTensor& y, \
int axis, \
DenseTensor* out) { \
out->mutable_data<T>(dev_ctx.GetPlace()); \
dev_ctx.template Alloc<T>(out); \
if (x.dims() == y.dims()) { \
SameDimsElementwiseCompute<SameDims##name##Functor<CPUContext, T>>()( \
dev_ctx, x, y, out); \
......@@ -85,7 +85,7 @@ void DivideRawKernel(const Context& dev_ctx,
int axis,
DenseTensor* out) {
// allocate memory for out
out->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(out);
if (x.dims() == y.dims() && std::is_floating_point<T>::value) {
SameDimsElementwiseCompute<SameDimsDivideFunctor<CPUContext, T>>()(
dev_ctx, x, y, out);
......
......@@ -119,7 +119,7 @@ void GetShuffledInput(const DeviceContext& dev_ctx,
GetShuffledDim(input.dims(), &shuffled_dims, dims, &perm_axis);
shuffled_input->ResizeAndAllocate(shuffled_dims);
shuffled_input->mutable_data<OutT>(dev_ctx.GetPlace());
dev_ctx.template Alloc<OutT>(shuffled_input);
pten::math::TransposeNormal<DeviceContext, OutT> trans;
trans(dev_ctx, input, shuffled_input, perm_axis);
......@@ -158,7 +158,7 @@ void ReduceKernelImpl(const DeviceContext& dev_ctx,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all) {
output->mutable_data<OutT>(dev_ctx.GetPlace());
dev_ctx.template Alloc<OutT>(output);
if (reduce_all) {
// Flatten and reduce 1-D tensor
......
......@@ -33,7 +33,7 @@ void ScaleKernel(const Context& dev_ctx,
bool bias_after_scale,
DenseTensor* out) {
// calc
out->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(out);
auto eigen_out = pten::EigenVector<T>::Flatten(*out);
auto eigen_x = pten::EigenVector<T>::Flatten(x);
auto& dev = *dev_ctx.eigen_device();
......
......@@ -29,7 +29,7 @@ void EmptyKernel(const Context& dev_ctx,
template <typename T, typename Context>
void EmptyLikeKernel(const Context& dev_ctx, DenseTensor* out) {
out->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(out);
}
} // namespace pten
......
......@@ -229,7 +229,7 @@ class TransformFunctor {
const bool is_xsize_larger = true)
: x_(x.data<T>()),
y_(y.data<T>()),
z_(z->mutable_data<OutType>(ctx.GetPlace())),
z_(ctx.template Alloc<OutType>(z)),
nx_(x.numel()),
ctx_(ctx),
func_(func),
......@@ -425,8 +425,8 @@ void ElemwiseGradComputeNoBroadcast(const DeviceContext &dev_ctx,
dout.data<Tout>(),
dx_op,
dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(dev_ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(dev_ctx.GetPlace())});
dx == nullptr ? nullptr : dev_ctx.template Alloc<T>(dx),
dy == nullptr ? nullptr : dev_ctx.template Alloc<T>(dy)});
}
inline void ElementwiseGradPreProcess(const DenseTensor &dout,
......@@ -631,7 +631,7 @@ void ElementwiseCudaKernel(const KPDevice &ctx,
ins_data[i] = ins[i]->data<InT>();
}
for (int i = 0; i < NumOuts; ++i) {
outs_data[i] = (*outs)[i]->mutable_data<OutT>(ctx.GetPlace());
outs_data[i] = ctx.Alloc<OutT>((*outs)[i]);
}
#ifdef PADDLE_WITH_XPU2
int block_size = 64;
......
......@@ -36,7 +36,7 @@ struct TransposeNormal<CPUContext, T> {
auto in_stride = pten::framework::stride(in.dims());
auto out_stride = pten::framework::stride(out->dims());
const T* in_ptr = in.data<T>();
T* out_ptr = out->mutable_data<T>(dev_ctx.GetPlace());
T* out_ptr = dev_ctx.template Alloc<T>(out);
auto transpose_helper = [&](int64_t beg, int64_t end) {
for (int64_t out_idx = beg; out_idx < end; ++out_idx) {
......
......@@ -61,7 +61,7 @@ struct TransposeNormal<GPUContext, T> {
auto in_stride = pten::framework::stride(in.dims());
auto out_stride = pten::framework::stride(out->dims());
auto* in_ptr = in.data<T>();
auto* out_ptr = out->mutable_data<T>(dev_ctx.GetPlace());
T* out_ptr = dev_ctx.template Alloc<T>(out);
// copy in_stride, out_stride, axis to gpu device
const paddle::platform::CUDAPlace& cuda_place = dev_ctx.GetPlace();
......
......@@ -43,7 +43,7 @@ void CastCUDAKernelImpl(const GPUContext& dev_ctx,
std::vector<DenseTensor*> outputs;
inputs.emplace_back(&x);
outputs.emplace_back(out);
out->mutable_data<OutT>(dev_ctx.GetPlace());
dev_ctx.Alloc<OutT>(out);
pten::funcs::LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary,
InT,
OutT>(
......
......@@ -29,7 +29,7 @@ void DotKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
out->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(out);
if (1 == out->dims().size()) {
auto eigen_out = pten::EigenScalar<T>::From(*out);
auto eigen_x = pten::EigenVector<T>::Flatten(x);
......
......@@ -352,7 +352,7 @@ void LaunchKernel(const KPDevice &ctx,
pten::framework::Array<_ptr_ OutT *, NumOuts> outs_data;
for (int i = 0; i < NumOuts; ++i) {
outs_data[i] = (*outs)[i]->mutable_data<OutT>(ctx.GetPlace());
outs_data[i] = ctx.Alloc<OutT>((*outs)[i]);
}
for (int i = 0; i < Arity; i++) {
......@@ -1264,8 +1264,8 @@ void CommonGradBroadcastCUDA(const DenseTensor &x,
const T *y_data = y.data<T>();
const Tout *out_data = out.data<Tout>();
const Tout *dout_data = dout.data<Tout>();
T *dx_data = dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace());
T *dy_data = dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace());
T *dx_data = dx == nullptr ? nullptr : ctx.Alloc<T>(dx);
T *dy_data = dy == nullptr ? nullptr : ctx.Alloc<T>(dy);
std::vector<int> x_one_indexs;
std::vector<int> y_one_indexs;
......@@ -1923,34 +1923,32 @@ void ElemwiseGradComputeWithBroadcast(const GPUContext &ctx,
return;
}
if (post == 1) {
ElemwiseGradBroadcast1CUDA(
ctx.stream(),
x.data<T>(),
y.data<T>(),
out.data<Tout>(),
dout.data<Tout>(),
pre,
n,
is_xsize_larger,
dx_op,
dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
ElemwiseGradBroadcast1CUDA(ctx.stream(),
x.data<T>(),
y.data<T>(),
out.data<Tout>(),
dout.data<Tout>(),
pre,
n,
is_xsize_larger,
dx_op,
dy_op,
dx == nullptr ? nullptr : ctx.Alloc<T>(dx),
dy == nullptr ? nullptr : ctx.Alloc<T>(dy));
} else {
ElemwiseGradBroadcast2CUDA(
ctx.stream(),
x.data<T>(),
y.data<T>(),
out.data<Tout>(),
dout.data<Tout>(),
pre,
n,
post,
is_xsize_larger,
dx_op,
dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
ElemwiseGradBroadcast2CUDA(ctx.stream(),
x.data<T>(),
y.data<T>(),
out.data<Tout>(),
dout.data<Tout>(),
pre,
n,
post,
is_xsize_larger,
dx_op,
dy_op,
dx == nullptr ? nullptr : ctx.Alloc<T>(dx),
dy == nullptr ? nullptr : ctx.Alloc<T>(dy));
}
}
......
......@@ -47,7 +47,7 @@ namespace pten {
inputs.emplace_back(&x); \
inputs.emplace_back(&y); \
outputs.emplace_back(out); \
out->mutable_data<T>(dev_ctx.GetPlace()); \
dev_ctx.template Alloc<T>(out); \
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>( \
dev_ctx, inputs, &outputs, axis, funcs::name##Functor<T>()); \
}
......
......@@ -54,7 +54,7 @@ void ScaleKernel(const Context& dev_ctx,
std::vector<DenseTensor*> outputs;
inputs.emplace_back(&x);
outputs.emplace_back(out);
out->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(out);
pten::funcs::LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary,
T,
T>(
......
......@@ -26,7 +26,7 @@ void ConjKernel(const Context& dev_ctx,
DenseTensor* out) {
auto numel = x.numel();
auto* x_data = x.data<T>();
auto* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
auto* out_data = dev_ctx.template Alloc<T>(out);
paddle::platform::ForRange<Context> for_range(dev_ctx, numel);
paddle::operators::math::ConjFunctor<T> functor(x_data, numel, out_data);
......
......@@ -73,7 +73,7 @@ struct DotGradFunction<DeviceContext,
auto dout = EigenMatrix<T>::From(*tensor_dout);
if (tensor_dx) {
tensor_dx->mutable_data<T>(ctx.GetPlace());
ctx.template Alloc<T>(tensor_dx);
auto y = EigenMatrix<T>::From(*tensor_y);
auto& dev = *ctx.eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
......@@ -85,7 +85,7 @@ struct DotGradFunction<DeviceContext,
}
if (tensor_dy) {
tensor_dy->mutable_data<T>(ctx.GetPlace());
ctx.template Alloc<T>(tensor_dy);
auto x = EigenMatrix<T>::From(*tensor_x);
auto& dev = *ctx.eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
......@@ -100,7 +100,7 @@ struct DotGradFunction<DeviceContext,
const auto* data_dout = tensor_dout->data<T>();
if (tensor_dx) {
auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
auto* data_dx = ctx.template Alloc<T>(tensor_dx);
const auto* data_y = tensor_y->data<T>();
const DDim& dim = tensor_x->dims();
size_t N = static_cast<size_t>(pten::framework::product(dim));
......@@ -115,7 +115,7 @@ struct DotGradFunction<DeviceContext,
}
if (tensor_dy) {
auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
auto* data_dy = ctx.template Alloc<T>(tensor_dy);
const auto* data_x = tensor_x->data<T>();
const DDim& dim = tensor_y->dims();
size_t N = static_cast<size_t>(pten::framework::product(dim));
......@@ -164,7 +164,7 @@ struct DotGradFunction<DeviceContext,
auto dout = EigenMatrix<T>::From(*tensor_dout);
if (tensor_dx) {
tensor_dx->mutable_data<T>(ctx.GetPlace());
ctx.template Alloc<T>(tensor_dx);
auto y = EigenMatrix<T>::From(*tensor_y);
auto dx = EigenMatrix<T>::From(*tensor_dx);
auto& dev = *ctx.eigen_device();
......@@ -173,7 +173,7 @@ struct DotGradFunction<DeviceContext,
}
if (tensor_dy) {
tensor_dy->mutable_data<T>(ctx.GetPlace());
ctx.template Alloc<T>(tensor_dy);
auto x = EigenMatrix<T>::From(*tensor_x);
auto dy = EigenMatrix<T>::From(*tensor_dy);
auto& dev = *ctx.eigen_device();
......@@ -189,7 +189,7 @@ struct DotGradFunction<DeviceContext,
auto const B = d[d.size() - 1];
if (tensor_dx) {
auto* dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
auto* dx = ctx.template Alloc<T>(tensor_dx);
for (auto j = 0; j < N / B; ++j) {
auto const ss = dz[j];
for (auto i = 0; i < B; ++i) *dx++ = *y++ * ss;
......@@ -197,7 +197,7 @@ struct DotGradFunction<DeviceContext,
}
if (tensor_dy) {
auto* dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
auto* dy = ctx.template Alloc<T>(tensor_dy);
for (auto j = 0; j < N / B; ++j) {
auto const ss = dz[j];
for (auto i = 0; i < B; i++) *dy++ = *x++ * ss;
......@@ -272,7 +272,7 @@ struct DotDoubleGradFunction<DeviceContext,
const auto* data_dout = tensor_dout->data<T>();
if (tensor_dx) {
auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
auto* data_dx = ctx.template Alloc<T>(tensor_dx);
const auto* data_ddy = tensor_ddy->data<T>();
const DDim& dim = tensor_dx->dims();
size_t N = static_cast<size_t>(product(dim));
......@@ -287,7 +287,7 @@ struct DotDoubleGradFunction<DeviceContext,
}
if (tensor_dy) {
auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
auto* data_dy = ctx.template Alloc<T>(tensor_dy);
const auto* data_ddx = tensor_ddx->data<T>();
const DDim& dim = tensor_dy->dims();
size_t N = static_cast<size_t>(product(dim));
......@@ -302,7 +302,7 @@ struct DotDoubleGradFunction<DeviceContext,
}
if (tensor_ddout) {
auto* data_ddout = tensor_ddout->mutable_data<T>(ctx.GetPlace());
auto* data_ddout = ctx.template Alloc<T>(tensor_ddout);
auto* data_x = tensor_x->data<T>();
auto* data_y = tensor_y->data<T>();
auto* data_ddx = tensor_ddx->data<T>();
......@@ -351,7 +351,7 @@ struct DotDoubleGradFunction<DeviceContext,
auto& dev = *ctx.eigen_device();
auto dout = EigenVector<T>::Flatten(*tensor_dout);
if (tensor_dx) {
tensor_dx->mutable_data<T>(ctx.GetPlace());
ctx.template Alloc<T>(tensor_dx);
auto ddy = EigenVector<T>::Flatten(*tensor_ddy);
Eigen::DSizes<int, 1> size(tensor_ddy->numel());
auto dx = EigenVector<T>::Flatten(*tensor_dx);
......@@ -359,7 +359,7 @@ struct DotDoubleGradFunction<DeviceContext,
}
if (tensor_dy) {
tensor_dy->mutable_data<T>(ctx.GetPlace());
ctx.template Alloc<T>(tensor_dy);
auto ddx = EigenVector<T>::Flatten(*tensor_ddx);
Eigen::DSizes<int, 1> size(tensor_ddx->numel());
......@@ -368,7 +368,7 @@ struct DotDoubleGradFunction<DeviceContext,
}
if (tensor_ddout) {
tensor_ddout->mutable_data<T>(ctx.GetPlace());
ctx.template Alloc<T>(tensor_ddout);
auto x = EigenVector<T>::Flatten(*tensor_x);
auto y = EigenVector<T>::Flatten(*tensor_y);
auto ddx = EigenVector<T>::Flatten(*tensor_ddx);
......@@ -381,7 +381,7 @@ struct DotDoubleGradFunction<DeviceContext,
const auto* data_dout = tensor_dout->data<T>();
if (tensor_dx) {
auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
auto* data_dx = ctx.template Alloc<T>(tensor_dx);
const auto* data_ddy = tensor_ddy->data<T>();
const DDim& dim = tensor_dx->dims();
size_t N = static_cast<size_t>(product(dim));
......@@ -396,7 +396,7 @@ struct DotDoubleGradFunction<DeviceContext,
}
if (tensor_dy) {
auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
auto* data_dy = ctx.template Alloc<T>(tensor_dy);
const auto* data_ddx = tensor_ddx->data<T>();
const DDim& dim = tensor_dy->dims();
size_t N = static_cast<size_t>(product(dim));
......@@ -411,7 +411,7 @@ struct DotDoubleGradFunction<DeviceContext,
}
if (tensor_ddout) {
auto* data_ddout = tensor_ddout->mutable_data<T>(ctx.GetPlace());
auto* data_ddout = ctx.template Alloc<T>(tensor_ddout);
auto* data_x = tensor_x->data<T>();
auto* data_y = tensor_y->data<T>();
auto* data_ddx = tensor_ddx->data<T>();
......@@ -552,7 +552,7 @@ struct DotTripleGradFunction<DeviceContext,
const auto* data_d_ddout = in_tensor_d_ddout->data<T>();
if (out_tensor_d_x) {
auto* data_d_x = out_tensor_d_x->mutable_data<T>(ctx.GetPlace());
auto* data_d_x = ctx.template Alloc<T>(out_tensor_d_x);
const auto* data_ddy = in_tensor_ddy->data<T>();
const DDim& dim = out_tensor_d_x->dims();
......@@ -567,7 +567,7 @@ struct DotTripleGradFunction<DeviceContext,
}
if (out_tensor_d_y) {
auto* data_d_y = out_tensor_d_y->mutable_data<T>(ctx.GetPlace());
auto* data_d_y = ctx.template Alloc<T>(out_tensor_d_y);
const auto* data_ddx = in_tensor_ddx->data<T>();
const DDim& dim = out_tensor_d_y->dims();
......@@ -582,7 +582,7 @@ struct DotTripleGradFunction<DeviceContext,
}
if (out_tensor_d_dout) {
auto* data_d_dout = out_tensor_d_dout->mutable_data<T>(ctx.GetPlace());
auto* data_d_dout = ctx.template Alloc<T>(out_tensor_d_dout);
auto* data_ddx = in_tensor_ddx->data<T>();
auto* data_ddy = in_tensor_ddy->data<T>();
auto* data_d_dx = in_tensor_d_dx->data<T>();
......@@ -613,7 +613,7 @@ struct DotTripleGradFunction<DeviceContext,
}
if (out_tensor_d_ddx) {
auto* data_d_ddx = out_tensor_d_ddx->mutable_data<T>(ctx.GetPlace());
auto* data_d_ddx = ctx.template Alloc<T>(out_tensor_d_ddx);
auto* data_dout = in_tensor_dout->data<T>();
auto* data_d_dy = in_tensor_d_dy->data<T>();
auto* data_y = in_tensor_y->data<T>();
......@@ -633,7 +633,7 @@ struct DotTripleGradFunction<DeviceContext,
}
if (out_tensor_d_ddy) {
auto* data_d_ddy = out_tensor_d_ddy->mutable_data<T>(ctx.GetPlace());
auto* data_d_ddy = ctx.template Alloc<T>(out_tensor_d_ddy);
auto* data_dout = in_tensor_dout->data<T>();
auto* data_d_dx = in_tensor_d_dx->data<T>();
auto* data_x = in_tensor_x->data<T>();
......@@ -678,7 +678,7 @@ struct DotTripleGradFunction<DeviceContext,
auto& dev = *ctx.eigen_device();
auto d_ddout = EigenVector<T>::Flatten(*in_tensor_d_ddout);
if (out_tensor_d_x) {
out_tensor_d_x->mutable_data<T>(ctx.GetPlace());
ctx.template Alloc<T>(out_tensor_d_x);
auto ddy = EigenVector<T>::Flatten(*in_tensor_ddy);
Eigen::DSizes<int, 1> size(in_tensor_ddy->numel());
auto d_x = EigenVector<T>::Flatten(*out_tensor_d_x);
......@@ -686,7 +686,7 @@ struct DotTripleGradFunction<DeviceContext,
}
if (out_tensor_d_y) {
out_tensor_d_y->mutable_data<T>(ctx.GetPlace());
ctx.template Alloc<T>(out_tensor_d_y);
auto ddx = EigenVector<T>::Flatten(*in_tensor_ddx);
Eigen::DSizes<int, 1> size(in_tensor_ddx->numel());
......@@ -695,7 +695,7 @@ struct DotTripleGradFunction<DeviceContext,
}
if (out_tensor_d_dout) {
out_tensor_d_dout->mutable_data<T>(ctx.GetPlace());
ctx.template Alloc<T>(out_tensor_d_dout);
auto ddx = EigenVector<T>::Flatten(*in_tensor_ddx);
auto ddy = EigenVector<T>::Flatten(*in_tensor_ddy);
auto d_dx = EigenVector<T>::Flatten(*in_tensor_d_dx);
......@@ -705,7 +705,7 @@ struct DotTripleGradFunction<DeviceContext,
}
if (out_tensor_d_ddx) {
out_tensor_d_ddx->mutable_data<T>(ctx.GetPlace());
ctx.template Alloc<T>(out_tensor_d_ddx);
auto dout = EigenVector<T>::Flatten(*in_tensor_dout);
auto y = EigenVector<T>::Flatten(*in_tensor_y);
auto d_ddout = EigenVector<T>::Flatten(*in_tensor_d_ddout);
......@@ -717,7 +717,7 @@ struct DotTripleGradFunction<DeviceContext,
}
if (out_tensor_d_ddy) {
out_tensor_d_ddy->mutable_data<T>(ctx.GetPlace());
ctx.template Alloc<T>(out_tensor_d_ddy);
auto dout = EigenVector<T>::Flatten(*in_tensor_dout);
auto x = EigenVector<T>::Flatten(*in_tensor_x);
auto d_ddout = EigenVector<T>::Flatten(*in_tensor_d_ddout);
......@@ -732,7 +732,7 @@ struct DotTripleGradFunction<DeviceContext,
const auto* data_d_ddout = in_tensor_d_ddout->data<T>();
if (out_tensor_d_x) {
auto* data_d_x = out_tensor_d_x->mutable_data<T>(ctx.GetPlace());
auto* data_d_x = ctx.template Alloc<T>(out_tensor_d_x);
const auto* data_ddy = in_tensor_ddy->data<T>();
const DDim& dim = out_tensor_d_x->dims();
......@@ -747,7 +747,7 @@ struct DotTripleGradFunction<DeviceContext,
}
if (out_tensor_d_y) {
auto* data_d_y = out_tensor_d_y->mutable_data<T>(ctx.GetPlace());
auto* data_d_y = ctx.template Alloc<T>(out_tensor_d_y);
const auto* data_ddx = in_tensor_ddx->data<T>();
const DDim& dim = out_tensor_d_y->dims();
......@@ -762,7 +762,7 @@ struct DotTripleGradFunction<DeviceContext,
}
if (out_tensor_d_dout) {
auto* data_d_dout = out_tensor_d_dout->mutable_data<T>(ctx.GetPlace());
auto* data_d_dout = ctx.template Alloc<T>(out_tensor_d_dout);
auto* data_ddx = in_tensor_ddx->data<T>();
auto* data_ddy = in_tensor_ddy->data<T>();
auto* data_d_dx = in_tensor_d_dx->data<T>();
......@@ -790,7 +790,7 @@ struct DotTripleGradFunction<DeviceContext,
}
if (out_tensor_d_ddx) {
auto* data_d_ddx = out_tensor_d_ddx->mutable_data<T>(ctx.GetPlace());
auto* data_d_ddx = ctx.template Alloc<T>(out_tensor_d_ddx);
auto* data_dout = in_tensor_dout->data<T>();
auto* data_d_dy = in_tensor_d_dy->data<T>();
auto* data_y = in_tensor_y->data<T>();
......@@ -809,7 +809,7 @@ struct DotTripleGradFunction<DeviceContext,
}
if (out_tensor_d_ddy) {
auto* data_d_ddy = out_tensor_d_ddy->mutable_data<T>(ctx.GetPlace());
auto* data_d_ddy = ctx.template Alloc<T>(out_tensor_d_ddy);
auto* data_dout = in_tensor_dout->data<T>();
auto* data_d_dx = in_tensor_d_dx->data<T>();
auto* data_x = in_tensor_x->data<T>();
......@@ -838,10 +838,10 @@ void DotGradKernel(const Context& dev_ctx,
DenseTensor* dx,
DenseTensor* dy) {
if (dx) {
dx->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(dx);
}
if (dy) {
dy->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(dy);
}
DotGradFunction<Context, T>()(dev_ctx, &x, &y, &dout, dx, dy);
}
......@@ -857,13 +857,13 @@ void DotDoubleGradKernel(const Context& dev_ctx,
DenseTensor* dy,
DenseTensor* ddout) {
if (dx) {
dx->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(dx);
}
if (dy) {
dy->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(dy);
}
if (ddout) {
ddout->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(ddout);
}
DotDoubleGradFunction<Context, T>()(
dev_ctx, &x, &y, &dout, ddx, ddy, dx, dy, ddout);
......@@ -885,19 +885,19 @@ void DotTripleGradKernel(const Context& dev_ctx,
DenseTensor* d_ddy,
DenseTensor* d_dout) {
if (d_x) {
d_x->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(d_x);
}
if (d_y) {
d_y->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(d_y);
}
if (d_ddx) {
d_ddx->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(d_ddx);
}
if (d_ddy) {
d_ddy->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(d_ddy);
}
if (d_dout) {
d_dout->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(d_dout);
}
DotTripleGradFunction<Context, T>()(dev_ctx,
......
......@@ -26,7 +26,7 @@ namespace pten {
template <typename T, typename Context, typename VType>
void FullValue(const Context& dev_ctx, DenseTensor* tensor, VType val) {
tensor->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(tensor);
auto t = pten::EigenVector<T>::Flatten(*tensor);
t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(val));
}
......
......@@ -105,7 +105,7 @@ void MatMul(const Context& dev_ctx,
bool trans_b,
DenseTensor* out,
bool flag = false) {
out->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(out);
auto blas = paddle::operators::math::GetBlas<Context, T>(dev_ctx);
auto mat_dim_a =
paddle::operators::math::CreateMatrixDescriptor(a.dims(), 0, trans_a);
......@@ -123,7 +123,7 @@ void MatMul(const Context& dev_ctx,
b.data<T>(),
mat_dim_b,
static_cast<T>(1),
out->data<T>(),
dev_ctx.template Alloc<T>(out),
static_cast<T>(flag));
}
......@@ -242,8 +242,8 @@ void MatmulGradKernel(const Context& dev_ctx,
// Case1 : x's or y's dim = 1
if (x_ndim == 1 && y_ndim == 1) {
if (dx) dx->mutable_data<T>(dev_ctx.GetPlace());
if (dy) dy->mutable_data<T>(dev_ctx.GetPlace());
if (dx) dev_ctx.template Alloc<T>(dx);
if (dy) dev_ctx.template Alloc<T>(dy);
if (out_grad.numel() == 1) {
DotGradFunction<Context, T>()(dev_ctx, &x, &y, &out_grad, dx, dy);
return;
......
......@@ -118,7 +118,7 @@ void MatMulFunction(const Context& dev_ctx,
N));
VLOG(3) << "MatMul's case 1";
Out->Resize({1});
Out->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(Out);
blas.GEMM(CblasNoTrans,
CblasTrans,
1,
......@@ -128,7 +128,7 @@ void MatMulFunction(const Context& dev_ctx,
y_data,
x_data,
static_cast<T>(flag),
Out->data<T>());
dev_ctx.template Alloc<T>(Out));
return;
}
......@@ -165,7 +165,7 @@ void MatMulFunction(const Context& dev_ctx,
out_dims.back() = y_dims.back();
}
Out->ResizeAndAllocate(pten::framework::make_ddim(out_dims));
Out->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(Out);
if (trans_y) {
const int M = Y.numel() / N;
VLOG(3) << "MatMul's case 2";
......@@ -176,7 +176,7 @@ void MatMulFunction(const Context& dev_ctx,
y_data,
x_data,
static_cast<T>(flag),
Out->data<T>());
dev_ctx.template Alloc<T>(Out));
} else {
const int M = y_dims[y_ndim - 1];
const int batch_size = Y.numel() / (M * N);
......@@ -189,7 +189,7 @@ void MatMulFunction(const Context& dev_ctx,
y_data,
x_data,
static_cast<T>(flag),
Out->data<T>());
dev_ctx.template Alloc<T>(Out));
} else {
VLOG(3) << "MatMul's case 4";
blas.BatchedGEMM(CblasTrans,
......@@ -201,7 +201,7 @@ void MatMulFunction(const Context& dev_ctx,
y_data,
x_data,
static_cast<T>(flag),
Out->data<T>(),
dev_ctx.template Alloc<T>(Out),
batch_size,
M * N,
0);
......@@ -243,7 +243,7 @@ void MatMulFunction(const Context& dev_ctx,
std::copy_n(x_dims.cbegin(), x_ndim - 1, out_dims.begin());
}
Out->ResizeAndAllocate(pten::framework::make_ddim(out_dims));
Out->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(Out);
if (trans_x) {
const int M = x_dims[x_ndim - 1];
......@@ -257,7 +257,7 @@ void MatMulFunction(const Context& dev_ctx,
x_data,
y_data,
static_cast<T>(flag),
Out->data<T>());
dev_ctx.template Alloc<T>(Out));
} else {
VLOG(3) << "MatMul's case 6";
blas.BatchedGEMM(CblasTrans,
......@@ -269,7 +269,7 @@ void MatMulFunction(const Context& dev_ctx,
x_data,
y_data,
static_cast<T>(flag),
Out->data<T>(),
dev_ctx.template Alloc<T>(Out),
batch_size,
M * N,
0);
......@@ -284,7 +284,7 @@ void MatMulFunction(const Context& dev_ctx,
x_data,
y_data,
static_cast<T>(flag),
Out->data<T>());
dev_ctx.template Alloc<T>(Out));
}
return;
}
......@@ -331,7 +331,7 @@ void MatMulFunction(const Context& dev_ctx,
out_broadcast_dims[ndim - 1] = N;
Out->ResizeAndAllocate(pten::framework::make_ddim(out_broadcast_dims));
Out->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(Out);
const int batch_dim = ndim - 2;
// broadcast message
......@@ -367,7 +367,7 @@ void MatMulFunction(const Context& dev_ctx,
x_data,
y_data,
static_cast<T>(flag),
Out->data<T>());
dev_ctx.template Alloc<T>(Out));
} else if (x_batch_size == 1) {
if (M == 1 && trans_y) {
VLOG(3) << "MatMul's case 9";
......@@ -378,7 +378,7 @@ void MatMulFunction(const Context& dev_ctx,
y_data,
x_data,
static_cast<T>(flag),
Out->data<T>());
dev_ctx.template Alloc<T>(Out));
} else {
VLOG(3) << "MatMul's case 10";
blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans,
......@@ -390,7 +390,7 @@ void MatMulFunction(const Context& dev_ctx,
x_data,
y_data,
static_cast<T>(flag),
Out->data<T>(),
dev_ctx.template Alloc<T>(Out),
out_batch_size,
0,
K * N);
......@@ -407,7 +407,7 @@ void MatMulFunction(const Context& dev_ctx,
x_data,
y_data,
static_cast<T>(flag),
Out->data<T>());
dev_ctx.template Alloc<T>(Out));
} else {
VLOG(3) << "MatMul's case 12";
blas.BatchedGEMM(CblasTrans,
......@@ -419,7 +419,7 @@ void MatMulFunction(const Context& dev_ctx,
x_data,
y_data,
static_cast<T>(flag),
Out->data<T>(),
dev_ctx.template Alloc<T>(Out),
out_batch_size,
M * K,
0);
......@@ -435,7 +435,7 @@ void MatMulFunction(const Context& dev_ctx,
x_data,
y_data,
static_cast<T>(flag),
Out->data<T>(),
dev_ctx.template Alloc<T>(Out),
out_batch_size,
M * K,
K * N);
......@@ -454,7 +454,7 @@ void MatMulFunction(const Context& dev_ctx,
x_ptr[i] = x_data + x_index * M * K;
y_ptr[i] = y_data + y_index * K * N;
out_ptr[i] = Out->data<T>() + i * M * N;
out_ptr[i] = dev_ctx.template Alloc<T>(Out) + i * M * N;
IndexIncreaseFromDims(batch_dim, out_broadcast_dims.data(), index.data());
}
VLOG(3) << "MatMul's case 14";
......
......@@ -26,7 +26,7 @@ template <typename T, typename Context>
void SignKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
out->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(out);
auto eigen_out = pten::EigenVector<T>::Flatten(*out);
auto eigen_x = pten::EigenVector<T>::Flatten(x);
......
......@@ -32,7 +32,7 @@ void ReshapeKernel(const Context& dev_ctx,
return;
}
out->set_meta(out_meta);
out->mutable_data(dev_ctx.GetPlace());
dev_ctx.Alloc(out);
pten::Copy(dev_ctx, x, false, out);
out->Resize(out_meta.dims);
out->ResetLoD(x.lod());
......
......@@ -30,7 +30,7 @@ void Copy(const Context& dev_ctx,
bool blocking,
DenseTensor* dst) {
auto* src_ptr = src.data();
auto* dst_ptr = dst->mutable_data(dev_ctx.GetPlace());
auto* dst_ptr = dev_ctx.Alloc(dst);
const auto& src_place = src.place();
const auto& dst_place = dst->place();
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/kernels/cast_kernel.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/common/data_type.h"
#include "paddle/pten/core/dense_tensor.h"
......@@ -48,6 +49,11 @@ TEST(DEV_API, cast) {
}
pten::CPUContext dev_ctx;
dev_ctx.SetDeviceAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
pten::DataType out_dtype = pten::DataType::FLOAT64;
// 2. test API
auto out = pten::Cast<float>(dev_ctx, dense_x, out_dtype);
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/pten/kernels/concat_kernel.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
......@@ -56,6 +57,10 @@ TEST(DEV_API, concat) {
// 2. test API
pten::CPUContext dev_ctx;
dev_ctx.SetDeviceAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
auto out = pten::Concat<float>(dev_ctx, inputs, 0);
// 3. check result
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/kernels/complex_kernel.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
......@@ -44,6 +45,10 @@ TEST(DEV_API, conj) {
}
pten::CPUContext dev_ctx;
dev_ctx.SetDeviceAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
// 2. test API
auto out = pten::Conj<paddle::complex64>(dev_ctx, dense_x);
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/copy_kernel.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
......@@ -57,6 +58,10 @@ TEST(DEV_API, copy) {
std::cout << typeid(a).name() << std::endl;
// 2. test API
pten::CPUContext dev_ctx;
dev_ctx.SetDeviceAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
pten::Copy(dev_ctx, *(dense_src.get()), false, dense_dst.get());
// 3. check result
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/pten/kernels/empty_kernel.h"
#include "paddle/pten/kernels/full_kernel.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
......@@ -32,6 +33,10 @@ using DDim = pten::framework::DDim;
TEST(DEV_API, empty) {
// 1. create input
pten::CPUContext dev_ctx;
dev_ctx.SetDeviceAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
// 2. test API
auto out = pten::Empty<float>(dev_ctx, {3, 2}, pten::DataType::INT32);
......@@ -58,6 +63,10 @@ TEST(DEV_API, empty_like) {
// 2. test API
pten::CPUContext dev_ctx;
dev_ctx.SetDeviceAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
auto out = pten::EmptyLike<float>(dev_ctx, dense_x);
// 3. check result
......@@ -74,6 +83,10 @@ TEST(DEV_API, full) {
// 2. test API
pten::CPUContext dev_ctx;
dev_ctx.SetDeviceAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
auto out = pten::Full<float>(dev_ctx, {3, 2}, val, pten::DataType::FLOAT32);
// 3. check result
......@@ -103,6 +116,10 @@ TEST(DEV_API, full_like) {
float val = 1.0;
pten::CPUContext dev_ctx;
dev_ctx.SetDeviceAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
// 2. test API
auto out = pten::FullLike<float>(dev_ctx, dense_x, val);
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/kernels/dot_kernel.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
......@@ -57,6 +58,10 @@ TEST(DEV_API, dot) {
// 2. test API
pten::CPUContext dev_ctx;
dev_ctx.SetDeviceAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
auto out = pten::Dot<float>(dev_ctx, dense_x, dense_y);
// 3. check result
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/kernels/math_kernel.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
......@@ -59,6 +60,10 @@ TEST(DEV_API, add) {
// 2. test API
pten::CPUContext dev_ctx;
dev_ctx.SetDeviceAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
auto dense_out = pten::Add<float>(dev_ctx, dense_x, dense_y);
// 3. check result
......@@ -107,6 +112,10 @@ TEST(DEV_API, subtract) {
// 2. test API
pten::CPUContext dev_ctx;
dev_ctx.SetDeviceAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
auto dense_out = pten::Subtract<float>(dev_ctx, dense_x, dense_y);
// 3. check result
......@@ -155,6 +164,10 @@ TEST(DEV_API, divide) {
// 2. test API
pten::CPUContext dev_ctx;
dev_ctx.SetDeviceAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
auto dense_out = pten::Divide<float>(dev_ctx, dense_x, dense_y);
// 3. check result
......@@ -203,6 +216,10 @@ TEST(DEV_API, multiply) {
// 2. test API
pten::CPUContext dev_ctx;
dev_ctx.SetDeviceAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
auto dense_out = pten::Multiply<float>(dev_ctx, dense_x, dense_y);
// 3. check result
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/kernels/flatten_kernel.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
......@@ -55,6 +56,10 @@ TEST(DEV_API, flatten) {
}
int start_axis = 1, stop_axis = 2;
pten::CPUContext dev_ctx;
dev_ctx.SetDeviceAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
// 2. test API
auto out = pten::Flatten<float>(dev_ctx, dense_x, start_axis, stop_axis);
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/pten/kernels/matmul_kernel.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
......@@ -54,6 +55,10 @@ TEST(DEV_API, dot) {
// 2. test API
pten::CPUContext dev_ctx;
dev_ctx.SetDeviceAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
auto out = Matmul<float, CPUContext>(dev_ctx, dense_x, dense_y, false, false);
// 3. check result
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/pten/kernels/math_kernel.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
......@@ -47,6 +48,10 @@ TEST(DEV_API, mean) {
// 2. test API
pten::CPUContext dev_ctx;
dev_ctx.SetDeviceAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
auto out = pten::Mean<float>(dev_ctx, dense_x, dims, false);
// 3. check result
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/pten/kernels/reshape_kernel.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
......@@ -47,6 +48,10 @@ TEST(DEV_API, reshape) {
// 2. test API
pten::CPUContext dev_ctx;
dev_ctx.SetDeviceAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
auto out = pten::Reshape<float>(dev_ctx, dense_x, shape);
// 3. check result
std::vector<int64_t> expect_shape = {12, 3};
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/pten/kernels/scale_kernel.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
......@@ -47,6 +48,10 @@ TEST(DEV_API, scale) {
// 2. test API
pten::CPUContext dev_ctx;
dev_ctx.SetDeviceAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
auto out =
pten::Scale<float>(dev_ctx, dense_x, scale, bias, bias_after_scale);
......@@ -85,6 +90,10 @@ TEST(DEV_API, scale_host) {
// 2. test API
pten::CPUContext dev_ctx;
dev_ctx.SetDeviceAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
auto out =
pten::Scale<float>(dev_ctx, dense_x, scale, bias, bias_after_scale);
......
......@@ -17,10 +17,10 @@ limitations under the License. */
#include "paddle/pten/kernels/math_kernel.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
namespace pten {
namespace tests {
......@@ -46,6 +46,10 @@ TEST(DEV_API, sum) {
std::vector<int64_t> axis = {0, 1};
pten::CPUContext dev_ctx;
dev_ctx.SetDeviceAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
// 2. test API
auto out =
pten::Sum<float>(dev_ctx, dense_x, axis, pten::DataType::FLOAT32, false);
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import os
os.environ['FLAGS_use_stream_safe_cuda_allocator'] = "true"
import sys
import unittest
import paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册