Unverified commit 064bc4b8 authored by: W Wilber, committed by: GitHub

[PTEN] Add cpu context (#38979)

* add cpu_context.

* update

* update

* update

* update

* update

* fix ci problem

* fix npu ci problem

* update

* fix ci compile
Parent 08793179
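
At a high level, this commit introduces a standalone pten::CPUContext and routes fluid operator kernels to it through a ConvertToPtenContext trait: each operator keeps receiving its usual platform device context and casts it to the matching pten context at the pten kernel call site. Below is a minimal sketch of that pattern; the trait and the SignKernel call mirror the diff, while the wrapper function itself is illustrative only and not part of the commit.

// Illustrative only, not part of the commit. Assumes the headers defining
// ConvertToPtenContext and pten::SignKernel are included.
template <typename DeviceContext, typename T>
void SignOnPtenSketch(const DeviceContext& dev_ctx,
                      const pten::DenseTensor& x, pten::DenseTensor* out) {
  using PtenCtx =
      typename paddle::framework::ConvertToPtenContext<DeviceContext>::TYPE;
  // On CPU, PtenCtx is pten::CPUContext; contexts without a specialization
  // map to themselves.
  pten::SignKernel<T, PtenCtx>(static_cast<const PtenCtx&>(dev_ctx), x, out);
}
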
......@@ -42,9 +42,6 @@ namespace framework {
class Scope;
class Variable;
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
......
......@@ -37,9 +37,6 @@ namespace paddle {
namespace framework {
class Scope;
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
......
......@@ -48,9 +48,6 @@ class Executor;
class ProgramDesc;
class Scope;
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
DECLARE_double(eager_delete_tensor_gb);
......
......@@ -48,9 +48,6 @@ class Executor;
class ProgramDesc;
class Scope;
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
......
......@@ -33,31 +33,39 @@ static void ScaleDeviceDispatch(const pten::DenseTensor& dense_tensor,
pten::DenseTensor* dense_out) {
switch (dense_tensor.dtype()) {
case pten::DataType::FLOAT64: {
pten::ScaleKernel<double, DeviceContext>(
dev_ctx, dense_tensor /* tensor */, scale /* scale */,
bias /* bias */, bias_after_scale /* bias_after_scale */,
dense_out /* out tensor */);
pten::ScaleKernel<double, typename paddle::framework::
ConvertToPtenContext<DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
dense_tensor /* tensor */, scale /* scale */, bias /* bias */,
bias_after_scale /* bias_after_scale */, dense_out /* out tensor */);
break;
}
case pten::DataType::FLOAT32: {
pten::ScaleKernel<float, DeviceContext>(
dev_ctx, dense_tensor /* tensor */, scale /* scale */,
bias /* bias */, bias_after_scale /* bias_after_scale */,
dense_out /* out tensor */);
pten::ScaleKernel<float, typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
dense_tensor /* tensor */, scale /* scale */, bias /* bias */,
bias_after_scale /* bias_after_scale */, dense_out /* out tensor */);
break;
}
case pten::DataType::INT64: {
pten::ScaleKernel<int64_t, DeviceContext>(
dev_ctx, dense_tensor /* tensor */, scale /* scale */,
bias /* bias */, bias_after_scale /* bias_after_scale */,
dense_out /* out tensor */);
pten::ScaleKernel<int64_t, typename paddle::framework::
ConvertToPtenContext<DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
dense_tensor /* tensor */, scale /* scale */, bias /* bias */,
bias_after_scale /* bias_after_scale */, dense_out /* out tensor */);
break;
}
case pten::DataType::INT32: {
pten::ScaleKernel<int32_t, DeviceContext>(
dev_ctx, dense_tensor /* tensor */, scale /* scale */,
bias /* bias */, bias_after_scale /* bias_after_scale */,
dense_out /* out tensor */);
pten::ScaleKernel<int32_t, typename paddle::framework::
ConvertToPtenContext<DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
dense_tensor /* tensor */, scale /* scale */, bias /* bias */,
bias_after_scale /* bias_after_scale */, dense_out /* out tensor */);
break;
}
default: {
......
......@@ -31,9 +31,6 @@ namespace paddle {
namespace framework {
class Variable;
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace pten {
......
......@@ -18,12 +18,6 @@
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace framework {
namespace details {
......
......@@ -33,10 +33,6 @@ namespace ir {
class Node;
} // namespace ir
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
......
......@@ -28,9 +28,6 @@ namespace ir {
class Node;
} // namespace ir
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
......
......@@ -25,12 +25,6 @@
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace framework {
......
......@@ -28,9 +28,6 @@ namespace ir {
class Node;
} // namespace ir
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
......
......@@ -46,9 +46,6 @@ namespace framework {
class ProgramDesc;
class Scope;
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
......
......@@ -26,12 +26,6 @@
#include "paddle/fluid/platform/device/mlu/device_context.h"
#endif
namespace paddle {
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace framework {
......
......@@ -18,12 +18,6 @@ limitations under the License. */
#include "paddle/fluid/framework/version.h"
namespace paddle {
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace framework {
......
......@@ -27,12 +27,6 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace framework {
......
......@@ -75,5 +75,16 @@ class KernelArgsNameMaker {
void SetAllocationForOutputTenosr(pten::DenseTensor* tensor,
const platform::Place& place);
// TODO(Wilber): support others device context.
template <typename T>
struct ConvertToPtenContext {
using TYPE = T;
};
template <>
struct ConvertToPtenContext<platform::CPUDeviceContext> {
using TYPE = pten::CPUContext;
};
} // namespace framework
} // namespace paddle
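
For orientation, the trait above resolves as follows: platform::CPUDeviceContext is mapped to the new pten::CPUContext, and any context without a specialization (GPU, XPU, and others, per the TODO) maps to itself. The compile-time checks below only restate that mapping and are not part of the commit; the CUDA case assumes a WITH_GPU build.

// Illustrative static_asserts (require <type_traits>), not part of the diff.
static_assert(
    std::is_same<paddle::framework::ConvertToPtenContext<
                     paddle::platform::CPUDeviceContext>::TYPE,
                 pten::CPUContext>::value,
    "CPUDeviceContext is converted to pten::CPUContext");
static_assert(
    std::is_same<paddle::framework::ConvertToPtenContext<
                     paddle::platform::CUDADeviceContext>::TYPE,
                 paddle::platform::CUDADeviceContext>::value,
    "contexts without a specialization map to themselves");
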
......@@ -14,12 +14,6 @@ limitations under the License. */
#include "paddle/fluid/framework/selected_rows.h"
namespace paddle {
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace framework {
......
......@@ -16,17 +16,13 @@
#include <string>
#include <vector>
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace platform {
class DeviceContext;
} // namespace platform
namespace framework {
class Variable;
} // namespace framework
} // namespace paddle
namespace paddle {
......
......@@ -173,6 +173,11 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
<< " | kernel key: " << pt_kernel_key
<< " | kernel: " << pt_kernel;
if (platform::is_cpu_place(expected_kernel_key.place_)) {
auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace());
return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature,
pt_kernel, cpu_ctx);
}
// TODO(chenweihang): using CPUKernel when miss device kernel case
return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature,
pt_kernel, dev_ctx);
......
......@@ -37,9 +37,6 @@ namespace paddle {
namespace framework {
class Variable;
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
......
......@@ -33,11 +33,6 @@
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace platform {
class DeviceContext;
} // namespace platform
namespace imperative {
class ParallelContext;
class VarBase;
......
......@@ -19,13 +19,9 @@ limitations under the License. */
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/stream/stream.h"
#include "paddle/pten/core/device_context.h"
namespace paddle {
namespace platform {
class DeviceContext;
} // platform
namespace memory {
using pten::Allocation;
......@@ -37,7 +33,7 @@ extern std::shared_ptr<Allocation> AllocShared(const platform::Place& place,
extern AllocationPtr Alloc(const platform::Place& place, size_t size);
extern AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size);
extern AllocationPtr Alloc(const pten::DeviceContext& dev_ctx, size_t size);
extern uint64_t Release(const platform::Place& place);
......
......@@ -19,12 +19,6 @@ limitations under the License. */
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace pten {
class DenseTensor;
} // namespace pten
......
......@@ -67,7 +67,10 @@ class CastOpKernel : public framework::OpKernel<InT> {
static_cast<framework::proto::VarType::Type>(out_dtype));
// call new kernel
pten::CastKernel<InT>(dev_ctx, *in, pt_out_dtype, out);
pten::CastKernel<InT>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*in, pt_out_dtype, out);
}
};
......
......@@ -202,7 +202,10 @@ class CholeskySolveGradKernel : public framework::OpKernel<T> {
commonterm_for_range(commonterm_functor);
commonterm_conj = helper.Transpose(commonterm_conj);
pten::AddKernel<T>(dev_ctx, commonterm, commonterm_conj, -1, &commonterm);
pten::AddKernel<T>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE &>(dev_ctx),
commonterm, commonterm_conj, -1, &commonterm);
auto mat_dim_u = math::CreateMatrixDescriptor(u_bst.dims(), 0, false);
auto mat_dim_c =
......
......@@ -36,7 +36,10 @@ class ConjKernel : public framework::OpKernel<T> {
auto& dev_ctx = context.device_context<DeviceContext>();
// call new kernel
pten::ConjKernel<T>(dev_ctx, *x, out);
pten::ConjKernel<T>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*x, out);
}
};
......
......@@ -41,7 +41,11 @@ class DotKernel : public framework::OpKernel<T> {
out->mutable_data<T>(x->place());
// call new kernel
pten::DotKernel<T, DeviceContext>(dev_ctx, *x, *y, out);
pten::DotKernel<T, typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*x, *y, out);
}
};
......@@ -61,8 +65,10 @@ class DotGradKernel : public framework::OpKernel<T> {
auto& dev_ctx = ctx.device_context<DeviceContext>();
// call new kernel
pten::DotGradKernel<T>(dev_ctx, *tensor_x, *tensor_y, *tensor_dout,
tensor_dx, tensor_dy);
pten::DotGradKernel<T>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*tensor_x, *tensor_y, *tensor_dout, tensor_dx, tensor_dy);
}
};
......
......@@ -61,7 +61,10 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
pten::AddKernel<T>(dev_ctx, *pt_x.get(), *pt_y.get(), axis, pt_z.get());
pten::AddKernel<T>(
static_cast<const typename framework::ConvertToPtenContext<
DeviceContext>::TYPE &>(dev_ctx),
*pt_x.get(), *pt_y.get(), axis, pt_z.get());
}
};
......
......@@ -51,7 +51,10 @@ class ElementwiseDivKernel : public framework::OpKernel<T> {
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
pten::DivideKernel<T>(dev_ctx, *pt_x.get(), *pt_y.get(), axis, pt_z.get());
pten::DivideKernel<T>(
static_cast<const typename framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*pt_x.get(), *pt_y.get(), axis, pt_z.get());
}
};
......
......@@ -124,8 +124,10 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x_lod);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
auto pt_z = paddle::experimental::MakePtenDenseTensor(*z_lod);
pten::MultiplyKernel<T>(dev_ctx, *pt_x.get(), *pt_y.get(), axis,
pt_z.get());
pten::MultiplyKernel<T>(
static_cast<const typename framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*pt_x.get(), *pt_y.get(), axis, pt_z.get());
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"X's type[%s] is not supported by elementwise_op. X's type should be "
......
......@@ -51,8 +51,10 @@ class ElementwiseSubKernel : public framework::OpKernel<T> {
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
pten::SubtractKernel<T>(dev_ctx, *pt_x.get(), *pt_y.get(), axis,
pt_z.get());
pten::SubtractKernel<T>(
static_cast<const typename framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*pt_x.get(), *pt_y.get(), axis, pt_z.get());
}
};
......
......@@ -62,7 +62,10 @@ class FillAnyLikeKernel : public framework::OpKernel<T> {
const auto& dev_ctx = context.template device_context<DeviceContext>();
// call new kernel
pten::FullLikeKernel<T>(dev_ctx, value, out);
pten::FullLikeKernel<T>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
value, out);
}
};
......
......@@ -57,9 +57,9 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> {
}
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(ctx.GetPlace());
bool cpu_place = force_cpu || ctx.GetPlace() == platform::CPUPlace();
if (cpu_place) {
auto &dev_ctx = *pool.Get(platform::CPUPlace());
math::SetConstant<platform::CPUDeviceContext, T> functor;
out->mutable_data(platform::CPUPlace(), data_type);
functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx),
......@@ -67,6 +67,7 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> {
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (!cpu_place) {
auto &dev_ctx = *pool.Get(ctx.GetPlace());
math::SetConstant<platform::CUDADeviceContext, T> functor;
out->mutable_data(ctx.GetPlace(), data_type);
functor(reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx),
......
......@@ -67,9 +67,9 @@ class FillConstantBatchSizeLikeOpNPUKernel : public framework::OpKernel<T> {
}
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(ctx.GetPlace());
bool cpu_place = force_cpu || ctx.GetPlace() == platform::CPUPlace();
if (cpu_place) {
auto &dev_ctx = *pool.Get(platform::CPUPlace());
math::SetConstant<platform::CPUDeviceContext, T> functor;
out->mutable_data(platform::CPUPlace(), data_type);
functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx),
......
......@@ -102,7 +102,6 @@ class FillConstantKernel : public framework::OpKernel<T> {
}
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(ctx.GetPlace());
int actual_place = place_type;
if (actual_place == -1) {
......@@ -123,12 +122,14 @@ class FillConstantKernel : public framework::OpKernel<T> {
: "<T>");
tensor->mutable_data(platform::CPUPlace(), data_type);
math::SetConstant<platform::CPUDeviceContext, T> functor;
auto &dev_ctx = *pool.Get(platform::CPUPlace());
functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx),
tensor, static_cast<T>(value));
} else if (actual_place == 1) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
tensor->mutable_data(ctx.GetPlace(), data_type);
math::SetConstant<platform::CUDADeviceContext, T> functor;
auto &dev_ctx = *pool.Get(ctx.GetPlace());
functor(reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx),
tensor, static_cast<T>(value));
#else
......@@ -138,8 +139,10 @@ class FillConstantKernel : public framework::OpKernel<T> {
} else if (actual_place == 2) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
tensor->mutable_data(platform::CUDAPinnedPlace(), data_type);
math::SetConstant<platform::CPUDeviceContext, T> functor;
functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx),
math::SetConstant<platform::CUDAPinnedDeviceContext, T> functor;
auto &dev_ctx = *pool.Get(platform::CUDAPinnedPlace());
functor(
reinterpret_cast<const platform::CUDAPinnedDeviceContext &>(dev_ctx),
tensor, static_cast<T>(value));
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
......@@ -149,6 +152,7 @@ class FillConstantKernel : public framework::OpKernel<T> {
#ifdef PADDLE_WITH_XPU
tensor->mutable_data(ctx.GetPlace(), data_type);
math::SetConstant<platform::XPUDeviceContext, T> functor;
auto &dev_ctx = *pool.Get(ctx.GetPlace());
functor(reinterpret_cast<const platform::XPUDeviceContext &>(dev_ctx),
tensor, static_cast<T>(value));
#else
......
......@@ -132,8 +132,11 @@ class FlattenContiguousRangeKernel : public framework::OpKernel<T> {
auto &dev_ctx = context.device_context<DeviceContext>();
// call new kernel
pten::FlattenKernel<T, DeviceContext>(dev_ctx, *in, start_axis, stop_axis,
out);
pten::FlattenKernel<T, typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE &>(dev_ctx),
*in, start_axis, stop_axis, out);
}
};
......@@ -150,7 +153,11 @@ class FlattenContiguousRangeGradKernel : public framework::OpKernel<T> {
auto &dev_ctx = ctx.device_context<DeviceContext>();
// call new kernel
pten::FlattenGradKernel<T, DeviceContext>(dev_ctx, *d_out, *xshape, d_x);
pten::FlattenGradKernel<T, typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE &>(dev_ctx),
*d_out, *xshape, d_x);
}
};
......
......@@ -31,7 +31,6 @@ namespace paddle {
namespace platform {
class CPUDeviceContext;
class CUDADeviceContext;
class DeviceContext;
} // namespace platform
} // namespace paddle
......
......@@ -221,7 +221,11 @@ void Tensor_Add(const DeviceContext& dev_ctx, const framework::Tensor& src1,
out->Resize(src1.dims());
out->mutable_data<T>(dev_ctx.GetPlace());
pten::AddKernel<T, DeviceContext>(dev_ctx, src1, src2, -1, out);
pten::AddKernel<
T, typename paddle::framework::ConvertToPtenContext<DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
src1, src2, -1, out);
}
template <typename DeviceContext, typename T>
......@@ -230,7 +234,11 @@ void Tensor_Sub(const DeviceContext& dev_ctx, const framework::Tensor& src1,
out->Resize(src1.dims());
out->mutable_data<T>(dev_ctx.GetPlace());
pten::SubtractKernel<T, DeviceContext>(dev_ctx, src1, src2, -1, out);
pten::SubtractKernel<
T, typename paddle::framework::ConvertToPtenContext<DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
src1, src2, -1, out);
}
template <typename DeviceContext, typename T, size_t D>
......
......@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/backends/cpu/cpu_context.h"
#ifdef PADDLE_WITH_MKLML
#include <mkl.h>
#endif
......@@ -819,6 +820,12 @@ T *Blas<platform::CPUDeviceContext>::GEMM_ALLOC(const CBLAS_IDENTIFIER id,
const int K) const {
return CBlas<T>::GEMM_ALLOC(id, M, N, K);
}
template <>
template <typename T>
T *Blas<pten::CPUContext>::GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M,
const int N, const int K) const {
return CBlas<T>::GEMM_ALLOC(id, M, N, K);
}
template <>
template <typename T>
......@@ -829,6 +836,15 @@ void Blas<platform::CPUDeviceContext>::GEMM_PACK(const CBLAS_IDENTIFIER id,
const int ld, T *dst) const {
CBlas<T>::GEMM_PACK(CblasRowMajor, id, trans, M, N, K, alpha, src, ld, dst);
}
template <>
template <typename T>
void Blas<pten::CPUContext>::GEMM_PACK(const CBLAS_IDENTIFIER id,
const CBLAS_TRANSPOSE trans, int M,
int N, int K, const T alpha,
const T *src, const int ld,
T *dst) const {
CBlas<T>::GEMM_PACK(CblasRowMajor, id, trans, M, N, K, alpha, src, ld, dst);
}
template <>
template <typename T>
......@@ -838,12 +854,26 @@ void Blas<platform::CPUDeviceContext>::GEMM_COMPUTE(
CBlas<T>::GEMM_COMPUTE(CblasRowMajor, transA, transB, M, N, K, A, lda, B, ldb,
beta, C, ldc);
}
template <>
template <typename T>
void Blas<pten::CPUContext>::GEMM_COMPUTE(int transA, int transB, int M, int N,
int K, const T *A, const int lda,
const T *B, const int ldb, T beta,
T *C, const int ldc) const {
CBlas<T>::GEMM_COMPUTE(CblasRowMajor, transA, transB, M, N, K, A, lda, B, ldb,
beta, C, ldc);
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::GEMM_FREE(T *data) const {
CBlas<T>::GEMM_FREE(data);
}
template <>
template <typename T>
void Blas<pten::CPUContext>::GEMM_FREE(T *data) const {
CBlas<T>::GEMM_FREE(data);
}
#endif
template <>
......@@ -858,6 +888,18 @@ void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
}
template <>
template <typename T>
void Blas<pten::CPUContext>::GEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB, int M, int N, int K,
T alpha, const T *A, const T *B, T beta,
T *C) const {
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
}
template <>
template <typename T>
......@@ -869,6 +911,15 @@ void Blas<platform::CPUDeviceContext>::GEMM(bool transA, bool transB, int M,
transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
lda, B, ldb, beta, C, ldc);
}
template <>
template <typename T>
void Blas<pten::CPUContext>::GEMM(bool transA, bool transB, int M, int N, int K,
T alpha, const T *A, int lda, const T *B,
int ldb, T beta, T *C, int ldc) const {
CBlas<T>::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
lda, B, ldb, beta, C, ldc);
}
template <>
template <typename T>
......@@ -880,6 +931,15 @@ void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
}
template <>
template <typename T>
void Blas<pten::CPUContext>::GEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB, int M, int N, int K,
T alpha, const T *A, int lda, const T *B,
int ldb, T beta, T *C, int ldc) const {
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
}
template <typename DeviceContext>
template <typename T>
......@@ -920,12 +980,22 @@ void Blas<platform::CPUDeviceContext>::AXPY(int n, T alpha, const T *x,
T *y) const {
CBlas<T>::AXPY(n, alpha, x, 1, y, 1);
}
template <>
template <typename T>
void Blas<pten::CPUContext>::AXPY(int n, T alpha, const T *x, T *y) const {
CBlas<T>::AXPY(n, alpha, x, 1, y, 1);
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::VCOPY(int n, const T *x, T *y) const {
CBlas<T>::VCOPY(n, x, 1, y, 1);
}
template <>
template <typename T>
void Blas<pten::CPUContext>::VCOPY(int n, const T *x, T *y) const {
CBlas<T>::VCOPY(n, x, 1, y, 1);
}
template <>
template <typename T>
......@@ -942,6 +1012,20 @@ void Blas<platform::CPUDeviceContext>::VADD(int n, const T *x, const T *y,
}
#endif
}
template <>
template <typename T>
void Blas<pten::CPUContext>::VADD(int n, const T *x, const T *y, T *z) const {
#ifdef PADDLE_WITH_MKLML
CBlas<T>::VADD(n, x, y, z);
#else
if (x == z) {
this->template AXPY<T>(n, (T)(1.), y, z);
} else {
this->template VCOPY<T>(n, y, z);
this->template AXPY<T>(n, (T)(1.), x, z);
}
#endif
}
template <>
template <typename T>
......@@ -956,6 +1040,18 @@ void Blas<platform::CPUDeviceContext>::VSUB(int n, const T *x, const T *y,
}
#endif
}
template <>
template <typename T>
void Blas<pten::CPUContext>::VSUB(int n, const T *x, const T *y, T *z) const {
#ifdef PADDLE_WITH_MKLML
CBlas<T>::VSUB(n, x, y, z);
#else
// try to find if openblas support vsub
for (int i = 0; i < n; ++i) {
z[i] = x[i] - y[i];
}
#endif
}
template <>
template <typename T>
......@@ -970,6 +1066,18 @@ void Blas<platform::CPUDeviceContext>::VMUL(int n, const T *x, const T *y,
}
#endif
}
template <>
template <typename T>
void Blas<pten::CPUContext>::VMUL(int n, const T *x, const T *y, T *z) const {
#ifdef PADDLE_WITH_MKLML
CBlas<T>::VMUL(n, x, y, z);
#else
// try to find if openblas support vmul
for (int i = 0; i < n; ++i) {
z[i] = x[i] * y[i];
}
#endif
}
template <>
template <typename T>
......@@ -984,6 +1092,18 @@ void Blas<platform::CPUDeviceContext>::VDIV(int n, const T *x, const T *y,
}
#endif
}
template <>
template <typename T>
void Blas<pten::CPUContext>::VDIV(int n, const T *x, const T *y, T *z) const {
#ifdef PADDLE_WITH_MKLML
CBlas<T>::VDIV(n, x, y, z);
#else
// try to find if openblas support vdiv
for (int i = 0; i < n; ++i) {
z[i] = x[i] / y[i];
}
#endif
}
template <>
template <typename T>
......@@ -997,6 +1117,18 @@ void Blas<platform::CPUDeviceContext>::VEXP(int n, const T *x, T *y) const {
}
#endif
}
template <>
template <typename T>
void Blas<pten::CPUContext>::VEXP(int n, const T *x, T *y) const {
#ifdef PADDLE_WITH_MKLML
CBlas<T>::VEXP(n, x, y);
#else
// try to find if openblas support vexp
for (int i = 0; i < n; ++i) {
y[i] = std::exp(x[i]);
}
#endif
}
template <>
template <typename T>
......@@ -1009,6 +1141,17 @@ void Blas<platform::CPUDeviceContext>::VSQUARE(int n, const T *x, T *y) const {
}
#endif
}
template <>
template <typename T>
void Blas<pten::CPUContext>::VSQUARE(int n, const T *x, T *y) const {
#ifdef PADDLE_WITH_MKLML
CBlas<T>::VSQUARE(n, x, y);
#else
for (int i = 0; i < n; ++i) {
y[i] = x[i] * x[i];
}
#endif
}
template <>
template <typename T>
......@@ -1022,6 +1165,17 @@ void Blas<platform::CPUDeviceContext>::VPOW(int n, const T *x, T a,
}
#endif
}
template <>
template <typename T>
void Blas<pten::CPUContext>::VPOW(int n, const T *x, T a, T *y) const {
#ifdef PADDLE_WITH_MKLML
CBlas<T>::VPOW(n, x, a, y);
#else
for (int i = 0; i < n; ++i) {
y[i] = std::pow(x[i], a);
}
#endif
}
template <>
template <typename T>
......@@ -1037,6 +1191,20 @@ T Blas<platform::CPUDeviceContext>::DOT(int n, const T *x, const T *y) const {
return sum;
#endif
}
template <>
template <typename T>
T Blas<pten::CPUContext>::DOT(int n, const T *x, const T *y) const {
#ifdef PADDLE_WITH_MKLML
return CBlas<T>::DOT(n, x, 1, y, 1);
#else
// try to find if openblas support cblas_dot
T sum = 0;
for (int i = 0; i < n; ++i) {
sum += x[i] * y[i];
}
return sum;
#endif
}
template <>
template <typename T>
......@@ -1050,6 +1218,18 @@ void Blas<platform::CPUDeviceContext>::SCAL(int n, const T a, T *x) const {
}
#endif
}
template <>
template <typename T>
void Blas<pten::CPUContext>::SCAL(int n, const T a, T *x) const {
#ifdef PADDLE_WITH_MKLML
CBlas<T>::SCAL(n, a, x, 1);
#else
// try to find if openblas support cblas_scal
for (int i = 0; i < n; ++i) {
x[i] = a * x[i];
}
#endif
}
template <>
template <typename T>
......@@ -1065,6 +1245,20 @@ T Blas<platform::CPUDeviceContext>::ASUM(int n, T *x, int inc) const {
#endif
return sum;
}
template <>
template <typename T>
T Blas<pten::CPUContext>::ASUM(int n, T *x, int inc) const {
auto sum = static_cast<T>(0.0);
#ifdef PADDLE_WITH_MKLML
sum = CBlas<T>::ASUM(n, x, inc);
#else
// TODO(jczaja): check if openblas does provide cblas_sasum/cblas_dasum
for (int c = 0; c < n; ++c) {
sum += x[c];
}
#endif
return sum;
}
template <>
template <typename T>
......@@ -1074,6 +1268,13 @@ void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha,
CBLAS_TRANSPOSE transA = !trans_a ? CblasNoTrans : CblasTrans;
CBlas<T>::GEMV(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1);
}
template <>
template <typename T>
void Blas<pten::CPUContext>::GEMV(bool trans_a, int M, int N, T alpha,
const T *A, const T *B, T beta, T *C) const {
CBLAS_TRANSPOSE transA = !trans_a ? CblasNoTrans : CblasTrans;
CBlas<T>::GEMV(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1);
}
template <>
template <typename T>
......@@ -1112,6 +1313,45 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
}
#endif
}
template <>
template <typename T>
void Blas<pten::CPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB, int M, int N,
int K, T alpha, const T *A, const T *B,
T beta, T *C, int batchCount,
int64_t strideA,
int64_t strideB) const {
PADDLE_ENFORCE_NOT_NULL(
A, platform::errors::InvalidArgument("Pointer A should not be null."));
PADDLE_ENFORCE_NOT_NULL(
B, platform::errors::InvalidArgument("Pointer B should not be null."));
PADDLE_ENFORCE_NOT_NULL(
C, platform::errors::InvalidArgument("Pointer C should not be null."));
#ifdef PADDLE_WITH_MKLML
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
auto a_array = std::vector<const T *>(batchCount);
auto b_array = std::vector<const T *>(batchCount);
auto c_array = std::vector<T *>(batchCount);
for (int k = 0; k < batchCount; ++k) {
a_array[k] = &A[k * strideA];
b_array[k] = &B[k * strideB];
c_array[k] = &C[k * M * N];
}
CBlas<T>::GEMM_BATCH(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha,
a_array.data(), &lda, b_array.data(), &ldb, &beta,
c_array.data(), &ldc, 1 /* group_count */, &batchCount);
#else
for (int k = 0; k < batchCount; ++k) {
auto *Ak = &A[k * strideA];
auto *Bk = &B[k * strideB];
auto *Ck = &C[k * M * N];
this->template GEMM<T>(transA, transB, M, N, K, alpha, Ak, Bk, beta, Ck);
}
#endif
}
template <>
template <typename T>
......@@ -1132,6 +1372,27 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
}
#endif
}
template <>
template <typename T>
void Blas<pten::CPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB, int M, int N,
int K, T alpha, const T **A,
const T **B, T beta, T **C,
int batchCount) const {
#ifdef PADDLE_WITH_MKLML
const int lda = (std::max)((transA == CblasNoTrans) ? K : M, 1);
const int ldb = (std::max)((transB == CblasNoTrans) ? N : K, 1);
const int ldc = (std::max)(N, 1);
CBlas<T>::GEMM_BATCH(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha, A,
&lda, B, &ldb, &beta, C, &ldc, 1 /* group_count */,
&batchCount);
#else
for (int k = 0; k < batchCount; ++k) {
this->template GEMM<T>(transA, transB, M, N, K, alpha, A[k], B[k], beta,
C[k]);
}
#endif
}
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \
!defined(PADDLE_WITH_HIP) // @{ Group Blas MKLML: BatchedGEMMWithHead
......@@ -1204,6 +1465,75 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMMWithHead(
}
}
}
template <>
template <typename T>
void Blas<pten::CPUContext>::BatchedGEMMWithHead(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int W1, int H1, int W2,
int H2, T alpha, const T *A, const T *B, T beta, T *C, int batchCount,
int64_t strideA, int64_t strideB, int64_t head_number,
bool split_b_vertical) const {
int lda = (transA == CblasNoTrans) ? W1 : H1;
int ldb = (transB == CblasNoTrans) ? W2 : H2;
auto a_array = std::vector<const T *>(batchCount);
auto b_array = std::vector<const T *>(batchCount);
auto c_array = std::vector<T *>(batchCount);
if (split_b_vertical) {
int ldc = W2;
int sub_width = W2 / head_number;
for (int i = 0; i < head_number; i++) {
int sub_matA_offset = (transA == CblasNoTrans)
? i * (W1 / head_number)
: i * (W1 / head_number) * H1;
int sub_matB_offset = (transB == CblasNoTrans)
? i * (W2 / head_number)
: i * (W2 / head_number) * H2;
int sub_matC_offset = i * W2 / head_number;
for (int k = 0; k < batchCount; ++k) {
a_array[k] = &A[k * strideA] + sub_matA_offset;
b_array[k] = &B[k * strideB] + sub_matB_offset;
c_array[k] = &C[k * H1 * W2] + sub_matC_offset;
}
CBlas<T>::GEMM_BATCH(CblasRowMajor, &transA, &transB, &H1, &sub_width,
&H2, &alpha, a_array.data(), &lda, b_array.data(),
&ldb, &beta, c_array.data(), &ldc,
1 /* group_count */, &batchCount);
}
} else {
PADDLE_ENFORCE_EQ(
W1, H2,
platform::errors::InvalidArgument(
"The fisrt matrix width should be same as second matrix height,"
"but received fisrt matrix width %d"
", second matrix height %d",
W1, H2));
int ldc = W2 * head_number;
int sub_width = W1 / head_number;
for (int i = 0; i < head_number; i++) {
int sub_matA_offset = (transA == CblasNoTrans)
? i * (W1 / head_number)
: i * (W1 / head_number) * H1;
int sub_matB_offset = (transB == CblasNoTrans)
? i * (W1 / head_number) * W2
: i * (W1 / head_number);
int sub_matC_offset = i * W2;
for (int k = 0; k < batchCount; ++k) {
a_array[k] = &A[k * strideA] + sub_matA_offset;
b_array[k] = &B[k * strideB] + sub_matB_offset;
c_array[k] = &C[k * H1 * head_number * W2] + sub_matC_offset;
}
CBlas<T>::GEMM_BATCH(CblasRowMajor, &transA, &transB, &H1, &W2,
&sub_width, &alpha, a_array.data(), &lda,
b_array.data(), &ldb, &beta, c_array.data(), &ldc,
1 /* group_count */, &batchCount);
}
}
}
#endif // @} End Group Blas MKLML: BatchedGEMMWithHead
template <typename DeviceContext>
......@@ -1241,6 +1571,31 @@ void Blas<platform::CPUDeviceContext>::MatMul(const int M, const int N,
CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K,
static_cast<T>(1), A, K, B, N, static_cast<T>(0), C, N);
}
template <>
template <typename T>
void Blas<pten::CPUContext>::MatMul(const int M, const int N, const int K,
const T *A, const T *B, T *C) const {
#ifdef PADDLE_WITH_LIBXSMM
// Refer to https://github.com/hfp/libxsmm/blob/master/README.md
// But the threshold is custom constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20;
// Since the matrix is very small,
// so the unit of calculation is already very fast,
// and the if( M*N*K < LIBXSMM_THRESHOLD) would be overhead,
// use xsmm directly.
// Note: SMM use ColMajor
const char transa = 'N';
const char transb = 'N';
const T alpha = static_cast<T>(1);
const T beta = static_cast<T>(0);
CBlas<T>::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &N, A, &K, &beta,
C, &N);
return;
#endif
CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K,
static_cast<T>(1), A, K, B, N, static_cast<T>(0), C, N);
}
template <typename DeviceContext>
template <typename T>
......@@ -1443,6 +1798,18 @@ void Blas<platform::CPUDeviceContext>::VMERF(int n, const T *a, T *y,
}
#endif
}
template <>
template <typename T>
void Blas<pten::CPUContext>::VMERF(int n, const T *a, T *y,
int64_t mode) const {
#ifdef PADDLE_WITH_MKLML
CBlas<T>::VMERF(n, a, y, mode);
#else
for (int i = 0; i < n; ++i) {
y[i] = std::erf(a[i]);
}
#endif
}
#ifdef PADDLE_WITH_MKLML
template <>
......@@ -1455,6 +1822,17 @@ void Blas<platform::CPUDeviceContext>::CSRMM(
CBlas<T>::CSRMM(transa, m, n, k, alpha, matdescra, val, indx, pntrb, pntre, b,
ldb, beta, c, ldc);
}
template <>
template <typename T>
void Blas<pten::CPUContext>::CSRMM(const char *transa, const int *m,
const int *n, const int *k, const T *alpha,
const char *matdescra, const T *val,
const int *indx, const int *pntrb,
const int *pntre, const T *b, const int *ldb,
const T *beta, T *c, const int *ldc) const {
CBlas<T>::CSRMM(transa, m, n, k, alpha, matdescra, val, indx, pntrb, pntre, b,
ldb, beta, c, ldc);
}
#endif
template <>
......@@ -1467,6 +1845,15 @@ void Blas<platform::CPUDeviceContext>::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo,
CBlas<T>::TRSM(CblasRowMajor, side, uplo, transA, diag, M, N, alpha, A, lda,
B, ldb);
}
template <>
template <typename T>
void Blas<pten::CPUContext>::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo,
CBLAS_TRANSPOSE transA, CBLAS_DIAG diag,
int M, int N, T alpha, const T *A, int lda,
T *B, int ldb) const {
CBlas<T>::TRSM(CblasRowMajor, side, uplo, transA, diag, M, N, alpha, A, lda,
B, ldb);
}
} // namespace math
} // namespace operators
......
......@@ -44,6 +44,22 @@ template struct SetConstant<platform::CUDADeviceContext,
template struct SetConstant<platform::CUDADeviceContext,
platform::complex<double>>;
template struct SetConstant<platform::CUDAPinnedDeviceContext,
platform::float16>;
template struct SetConstant<platform::CUDAPinnedDeviceContext,
platform::bfloat16>;
template struct SetConstant<platform::CUDAPinnedDeviceContext, float>;
template struct SetConstant<platform::CUDAPinnedDeviceContext, double>;
template struct SetConstant<platform::CUDAPinnedDeviceContext, uint8_t>;
template struct SetConstant<platform::CUDAPinnedDeviceContext, int>;
template struct SetConstant<platform::CUDAPinnedDeviceContext, int16_t>;
template struct SetConstant<platform::CUDAPinnedDeviceContext, int64_t>;
template struct SetConstant<platform::CUDAPinnedDeviceContext, bool>;
template struct SetConstant<platform::CUDAPinnedDeviceContext,
platform::complex<float>>;
template struct SetConstant<platform::CUDAPinnedDeviceContext,
platform::complex<double>>;
#define DEFINE_GPU_TRANS(RANK) \
template struct Transpose<platform::CUDADeviceContext, bool, RANK>; \
template struct Transpose<platform::CUDADeviceContext, float, RANK>; \
......
......@@ -53,7 +53,10 @@ class MatMulV2Kernel : public framework::OpKernel<T> {
Out->mutable_data<T>(X->place());
// call new kernel
pten::MatmulKernel<T>(dev_ctx, *X, *Y, trans_x, trans_y, Out);
pten::MatmulKernel<T>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*X, *Y, trans_x, trans_y, Out);
}
};
......@@ -149,8 +152,10 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
auto& dev_ctx = ctx.device_context<DeviceContext>();
// call new kernel
pten::MatmulGradKernel<T>(dev_ctx, *x, *y, *dout, transpose_x, transpose_y,
dx, dy);
pten::MatmulGradKernel<T>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*x, *y, *dout, transpose_x, transpose_y, dx, dy);
}
};
......@@ -178,8 +183,10 @@ class MatMulV2DoubleGradKernel : public framework::OpKernel<T> {
auto& dev_ctx = context.device_context<DeviceContext>();
// call new kernel
pten::MatmulDoubleGradKernel<T>(dev_ctx, *x, *y, *dout, *ddx, *ddy,
transpose_x, transpose_y, dx, dy, ddout);
pten::MatmulDoubleGradKernel<T>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*x, *y, *dout, *ddx, *ddy, transpose_x, transpose_y, dx, dy, ddout);
}
};
......@@ -218,7 +225,9 @@ class MatMulV2TripleGradKernel : public framework::OpKernel<T> {
auto& dev_ctx = context.device_context<DeviceContext>();
// call new kernel
pten::MatmulTripleGradKernel<T>(
dev_ctx, *x, *y, *dout, *ddx, *ddy, *d_dx, *d_dy, *d_ddout, transpose_x,
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*x, *y, *dout, *ddx, *ddy, *d_dx, *d_dy, *d_ddout, transpose_x,
transpose_y, out_d_x, out_d_y, out_d_dout, out_d_ddx, out_d_ddy);
}
};
......
......@@ -16,12 +16,6 @@ limitations under the License. */
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace pten {
class DenseTensor;
} // namespace pten
......
......@@ -17,12 +17,6 @@ limitations under the License. */
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/stream/stream.h"
namespace paddle {
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace pten {
class DenseTensor;
} // namespace pten
......
......@@ -19,12 +19,6 @@ limitations under the License. */
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace pten {
class DenseTensor;
} // namespace pten
......
......@@ -45,9 +45,6 @@ class Executor;
class ProgramDesc;
class Scope;
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
......
......@@ -22,12 +22,6 @@ limitations under the License. */
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
......
......@@ -258,8 +258,11 @@ class ReduceKernel : public framework::OpKernel<T> {
std::vector<int64_t> tmp_dims(dims.begin(), dims.end());
// call new kernel
pten::Reduce<DeviceContext, T, Functor>(
dev_ctx, *pt_x.get(), reduce_all, tmp_dims, keep_dim,
pten::Reduce<typename framework::ConvertToPtenContext<DeviceContext>::TYPE,
T, Functor>(
static_cast<const typename framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*pt_x.get(), reduce_all, tmp_dims, keep_dim,
pten::TransToPtenDataType(cast_out_dtype), pt_out.get());
}
};
......
......@@ -19,6 +19,7 @@ limitations under the License. */
// only can include the headers in paddle/pten/api dirs
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/kernels/reshape_grad_kernel.h"
#include "paddle/pten/kernels/reshape_kernel.h"
......@@ -435,7 +436,8 @@ class ReshapeKernel {
}
if (platform::is_cpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
pten::ReshapeKernel(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out);
pten::ReshapeKernel(static_cast<const pten::CPUContext &>(dev_ctx),
*pt_x.get(), pt_scalar_shape, pt_out);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) {
......@@ -471,7 +473,8 @@ class ReshapeGradKernel {
if (platform::is_cpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
pten::ReshapeGradKernel(dev_ctx, *pt_d_out.get(), pt_d_x.get());
pten::ReshapeGradKernel(static_cast<const pten::CPUContext &>(dev_ctx),
*pt_d_out.get(), pt_d_x.get());
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) {
......@@ -500,7 +503,9 @@ class ReshapeDoubleGradKernel {
if (platform::is_cpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
pten::ReshapeDoubleGradKernel(dev_ctx, *pt_dd_x.get(), pt_dd_out.get());
pten::ReshapeDoubleGradKernel(
static_cast<const pten::CPUContext &>(dev_ctx), *pt_dd_x.get(),
pt_dd_out.get());
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) {
......
......@@ -67,7 +67,10 @@ class ScaleKernel : public framework::OpKernel<T> {
auto& dev_ctx = ctx.device_context<DeviceContext>();
// call new kernel
pten::ScaleKernel<T>(dev_ctx, *in, scale, bias, bias_after_scale, out);
pten::ScaleKernel<T>(
static_cast<const typename framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*in, scale, bias, bias_after_scale, out);
}
};
......
......@@ -30,7 +30,7 @@ class GPUSeedKernel : public framework::OpKernel<T> {
if (cpu_place) {
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(context.GetPlace());
auto &dev_ctx = *pool.Get(platform::CPUPlace());
out->mutable_data<T>(platform::CPUPlace());
math::SetConstant<platform::CPUDeviceContext, T> functor;
functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx),
......
......@@ -35,7 +35,11 @@ class SignKernel : public framework::OpKernel<T> {
out->mutable_data<T>(x->place());
// call new kernel
pten::SignKernel<T, DeviceContext>(dev_ctx, *x, out);
pten::SignKernel<T, typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*x, out);
}
};
......
......@@ -21,12 +21,6 @@
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace pten {
class DenseTensor;
} // namespace pten
......
......@@ -123,7 +123,7 @@ cc_library(cudnn_workspace_helper SRCS cudnn_workspace_helper.cc DEPS boost)
# avoiding cycle dependencies
cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc xxhash ${STREAM_CALLBACK_DEPS}
place pten_place eigen3 stringpiece cpu_helper cpu_info framework_proto ${IPU_CTX_DEPS} ${GPU_CTX_DEPS} ${NPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}
${dgc_deps} dlpack cudnn_workspace_helper ${XPU_CTX_DEPS} ${MLU_CTX_DEPS})
${dgc_deps} dlpack cudnn_workspace_helper ${XPU_CTX_DEPS} ${MLU_CTX_DEPS} cpu_context)
cc_library(collective_helper SRCS collective_helper.cc gen_comm_id_helper.cc DEPS framework_proto device_context enforce)
if(WITH_ASCEND_CL)
......
......@@ -21,8 +21,6 @@ struct DefaultDevice;
struct GpuDevice;
} // namespace Eigen
// class DeviceContext;
namespace paddle {
namespace platform {
......
......@@ -132,7 +132,7 @@ platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
return it->second.get().get();
}
template <typename DevCtx, typename PlaceType>
template <typename DevCtx>
inline void EmplaceDeviceContext(
std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
map_ptr,
......@@ -158,19 +158,14 @@ DeviceContextPool::DeviceContextPool(
}
for (auto& p : set) {
if (platform::is_cpu_place(p)) {
platform::CPUPlace place;
#ifdef PADDLE_WITH_MKLDNN
EmplaceDeviceContext<MKLDNNDeviceContext, CPUPlace>(&device_contexts_,
place);
EmplaceDeviceContext<MKLDNNDeviceContext>(&device_contexts_, p);
#else
EmplaceDeviceContext<CPUDeviceContext, CPUPlace>(&device_contexts_,
place);
EmplaceDeviceContext<CPUDeviceContext>(&device_contexts_, p);
#endif
} else if (platform::is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
platform::CUDAPlace place(p.GetDeviceId());
EmplaceDeviceContext<CUDADeviceContext, CUDAPlace>(&device_contexts_,
place);
EmplaceDeviceContext<CUDADeviceContext>(&device_contexts_, p);
#else
PADDLE_THROW(
platform::errors::Unimplemented("CUDAPlace is not supported. Please "
......@@ -178,9 +173,7 @@ DeviceContextPool::DeviceContextPool(
#endif
} else if (platform::is_cuda_pinned_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
platform::CUDAPinnedPlace place;
EmplaceDeviceContext<CUDAPinnedDeviceContext, CUDAPinnedPlace>(
&device_contexts_, place);
EmplaceDeviceContext<CUDAPinnedDeviceContext>(&device_contexts_, p);
#else
PADDLE_THROW(platform::errors::Unimplemented(
"CUDAPlace is not supported. Please re-compile with WITH_GPU "
......@@ -188,9 +181,7 @@ DeviceContextPool::DeviceContextPool(
#endif
} else if (platform::is_xpu_place(p)) {
#ifdef PADDLE_WITH_XPU
platform::XPUPlace place(p.GetDeviceId());
EmplaceDeviceContext<XPUDeviceContext, XPUPlace>(&device_contexts_,
place);
EmplaceDeviceContext<XPUDeviceContext>(&device_contexts_, p);
#else
PADDLE_THROW(
platform::errors::Unimplemented("XPUPlace is not supported. Please "
......@@ -198,9 +189,7 @@ DeviceContextPool::DeviceContextPool(
#endif
} else if (platform::is_mlu_place(p)) {
#ifdef PADDLE_WITH_MLU
platform::MLUPlace place(p.GetDeviceId());
EmplaceDeviceContext<MLUDeviceContext, MLUPlace>(&device_contexts_,
place);
EmplaceDeviceContext<MLUDeviceContext>(&device_contexts_, p);
#else
PADDLE_THROW(
platform::errors::Unimplemented("MLUPlace is not supported. Please "
......@@ -208,9 +197,7 @@ DeviceContextPool::DeviceContextPool(
#endif
} else if (platform::is_ipu_place(p)) {
#ifdef PADDLE_WITH_IPU
platform::IPUPlace place(p.GetDeviceId());
EmplaceDeviceContext<IPUDeviceContext, IPUPlace>(&device_contexts_,
place);
EmplaceDeviceContext<IPUDeviceContext>(&device_contexts_, p);
#else
PADDLE_THROW(
platform::errors::Unimplemented("IPUPlace is not supported. Please "
......@@ -218,9 +205,7 @@ DeviceContextPool::DeviceContextPool(
#endif
} else if (platform::is_npu_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
platform::NPUPlace place(p.GetDeviceId());
EmplaceDeviceContext<NPUDeviceContext, NPUPlace>(&device_contexts_,
place);
EmplaceDeviceContext<NPUDeviceContext>(&device_contexts_, p);
#else
PADDLE_THROW(platform::errors::Unimplemented(
"NPUPlace is not supported. Please "
......@@ -228,9 +213,7 @@ DeviceContextPool::DeviceContextPool(
#endif
} else if (platform::is_npu_pinned_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL
platform::NPUPinnedPlace place;
EmplaceDeviceContext<NPUPinnedDeviceContext, NPUPinnedPlace>(
&device_contexts_, place);
EmplaceDeviceContext<NPUPinnedDeviceContext>(&device_contexts_, p);
#else
PADDLE_THROW(platform::errors::Unimplemented(
"NPUPinnedPlace is not supported. Please re-compile with "
......@@ -241,19 +224,9 @@ DeviceContextPool::DeviceContextPool(
}
}
CPUDeviceContext::CPUDeviceContext() {
eigen_device_.reset(new Eigen::DefaultDevice());
}
CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) {
eigen_device_.reset(new Eigen::DefaultDevice());
}
Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
return eigen_device_.get();
}
CPUDeviceContext::CPUDeviceContext() : pten::CPUContext() {}
Place CPUDeviceContext::GetPlace() const { return place_; }
CPUDeviceContext::CPUDeviceContext(CPUPlace place) : pten::CPUContext() {}
#ifdef PADDLE_WITH_IPU
IPUDeviceContext::IPUDeviceContext(IPUPlace place) : place_(place) {
......
......@@ -18,6 +18,9 @@ limitations under the License. */
#include <utility>
#include <vector>
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/device_context.h"
#include "paddle/fluid/memory/malloc.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/device/gpu/gpu_helper.h"
......@@ -117,26 +120,15 @@ constexpr DeviceType kNPU = DeviceType::NPU;
constexpr DeviceType kIPU = DeviceType::IPU;
constexpr DeviceType kMLU = DeviceType::MLU;
class DeviceContext {
public:
virtual ~DeviceContext() PADDLE_MAY_THROW {}
virtual Place GetPlace() const = 0;
virtual void Wait() const {}
};
using DeviceContext = pten::DeviceContext;
class CPUDeviceContext : public DeviceContext {
// using CPUDeviceContext = pten::CPUContext;
// TODO(wilber): The place constructor is used in many places, it is more
// difficult to use CPUDeviceContext = pten::CPUContext directly.
class CPUDeviceContext : public pten::CPUContext {
public:
CPUDeviceContext();
explicit CPUDeviceContext(CPUPlace place);
Eigen::DefaultDevice* eigen_device() const;
Place GetPlace() const override;
private:
CPUPlace place_;
std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};
template <typename Place>
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
namespace paddle {
namespace platform {
......@@ -27,6 +28,7 @@ struct ForRange {
void operator()(Function func) const;
};
// NOTE: After the pten kernel is migrated, it needs to be deleted.
template <>
struct ForRange<CPUDeviceContext> {
ForRange(const CPUDeviceContext& dev_ctx, size_t limit) : limit_(limit) {}
......@@ -41,6 +43,20 @@ struct ForRange<CPUDeviceContext> {
size_t limit_;
};
template <>
struct ForRange<pten::CPUContext> {
ForRange(const pten::CPUContext& dev_ctx, size_t limit) : limit_(limit) {}
template <typename Function>
void operator()(Function func) const {
for (size_t i = 0; i < limit_; ++i) {
func(i);
}
}
size_t limit_;
};
#if defined(__NVCC__) || defined(__HIPCC__)
template <typename Function>
__global__ static void ForRangeElemwiseOpGridIsOne(Function func) {
......
......@@ -59,6 +59,7 @@ struct Transform {
BinaryOperation op);
};
// NOTE: After the pten kernel is migrated, it needs to be deleted.
template <>
struct Transform<platform::CPUDeviceContext> {
template <typename InputIter, typename OutputIter, typename UnaryOperation>
......@@ -76,6 +77,23 @@ struct Transform<platform::CPUDeviceContext> {
}
};
template <>
struct Transform<pten::CPUContext> {
template <typename InputIter, typename OutputIter, typename UnaryOperation>
void operator()(const pten::CPUContext& context, InputIter first,
InputIter last, OutputIter result, UnaryOperation op) {
std::transform(first, last, result, op);
}
template <typename InputIter1, typename InputIter2, typename OutputIter,
typename BinaryOperation>
void operator()(const pten::CPUContext& context, InputIter1 first1,
InputIter1 last1, InputIter2 first2, OutputIter result,
BinaryOperation op) {
std::transform(first1, last1, first2, result, op);
}
};
#if defined(__NVCC__) || defined(__HIPCC__)
template <>
struct Transform<platform::CUDADeviceContext> {
......
add_subdirectory(cpu)
cc_library(pten_context SRCS all_context.cc DEPS device_context)
......@@ -24,7 +24,9 @@ limitations under the License. */
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/backends/xpu/xpu_context.h"
// TODO(wilber): DeviceContextPool nees include fluid file.
#include "paddle/fluid/platform/device_context.h"
namespace pten {
using DeviceContext = paddle::platform::DeviceContext;
using DeviceContextPool = paddle::platform::DeviceContextPool;
} // namespace pten
if(WITH_MKLDNN)
# TODO(wilber): support mkldnn context.
cc_library(cpu_context SRCS cpu_context.cc DEPS pten_device_context mkldnn)
else()
cc_library(cpu_context SRCS cpu_context.cc DEPS pten_device_context)
endif()
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/api/ext/exception.h"
// NOTE: The paddle framework should add WITH_EIGEN option to support compile
// without eigen.
#include "paddle/pten/core/device_context.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace pten {
struct CPUContext::CPUImpl {
Eigen::DefaultDevice* device_{nullptr};
CPUContextResource res_;
CPUPlace place_;
CPUImpl() { device_ = new Eigen::DefaultDevice(); }
// Users need to manage external resources.
explicit CPUImpl(const CPUContextResource& ctx_res) : res_(ctx_res) {
device_ = res_.device;
}
~CPUImpl() {
if (res_.device == nullptr) {
delete device_;
device_ = nullptr;
}
}
Eigen::DefaultDevice* GetEigenDevice() const {
PD_CHECK(device_ != nullptr, "the eigen_device is nullptr.");
return device_;
}
void SetEigenDevice(Eigen::DefaultDevice* device) {
if (device == nullptr) {
return;
}
res_.device = device;
device_ = device;
}
Place GetPlace() const { return place_; }
};
CPUContext::CPUContext() : DeviceContext(), cpu_impl_(nullptr) {
cpu_impl_ = std::make_unique<CPUImpl>();
}
CPUContext::CPUContext(const CPUContext& other)
: DeviceContext(), cpu_impl_(nullptr) {
cpu_impl_ = std::make_unique<CPUImpl>();
cpu_impl_->SetEigenDevice(other.eigen_device());
}
CPUContext::CPUContext(CPUContext&& other)
: DeviceContext(), cpu_impl_(nullptr) {
cpu_impl_ = std::move(other.cpu_impl_);
}
CPUContext::~CPUContext() = default;
CPUContext::CPUContext(const CPUContextResource& ctx_res)
: DeviceContext(), cpu_impl_(nullptr) {
cpu_impl_ = std::make_unique<CPUImpl>(ctx_res);
}
Eigen::DefaultDevice* CPUContext::eigen_device() const {
return cpu_impl_->GetEigenDevice();
}
void CPUContext::SetEigenDevice(Eigen::DefaultDevice* device) {
cpu_impl_->SetEigenDevice(device);
}
Place CPUContext::GetPlace() const { return cpu_impl_->GetPlace(); }
} // namespace pten
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -14,9 +14,47 @@ limitations under the License. */
#pragma once
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/device_context.h"
#include <memory>
#include "paddle/pten/backends/cpu/forwards.h"
#include "paddle/pten/core/device_context.h"
// TODO(wilber): Do we need to use place in pten kernel?
#include "paddle/pten/common/place.h"
namespace pten {
using CPUContext = paddle::platform::CPUDeviceContext;
struct CPUContextResource {
Eigen::DefaultDevice* device{nullptr};
};
class CPUContext : public DeviceContext {
public:
// NOTE: DeviceContext hold resources. Used in training scenarios.
CPUContext();
  // NOTE: The copied or moved context shares the same underlying resources;
  // make sure those resources are not released while still in use.
CPUContext(const CPUContext&);
CPUContext(CPUContext&&);
~CPUContext();
Eigen::DefaultDevice* eigen_device() const;
  // TODO(wilber): Decide whether this interface needs to be kept.
Place GetPlace() const override;
public:
  // NOTE: Resources are created and managed by the caller. Used in inference
  // scenarios.
explicit CPUContext(const CPUContextResource& ctx_res);
void SetEigenDevice(Eigen::DefaultDevice* device);
private:
struct CPUImpl;
std::unique_ptr<CPUImpl> cpu_impl_;
};
} // namespace pten
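Not part of this patch: a minimal usage sketch of the two construction modes described by the notes above, assuming only the interfaces declared in this header; the function name is made up for illustration.
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "unsupported/Eigen/CXX11/Tensor"

void CpuContextUsageSketch() {
  // Training scenario: the context creates and owns its Eigen device.
  pten::CPUContext owned_ctx;
  Eigen::DefaultDevice* d0 = owned_ctx.eigen_device();  // always non-null

  // Inference scenario: the caller owns the Eigen device and must keep it
  // alive while the context is in use; the context will not delete it.
  Eigen::DefaultDevice external_device;
  pten::CPUContextResource res{&external_device};
  pten::CPUContext shared_ctx(res);
  Eigen::DefaultDevice* d1 = shared_ctx.eigen_device();  // == &external_device
  (void)d0;
  (void)d1;
}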
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

// Forward declaration of the Eigen::DefaultDevice type used by CPUContext.
namespace Eigen {
struct DefaultDevice;
} // namespace Eigen
......@@ -13,8 +13,10 @@ cc_library(arg_map_context SRCS arg_map_context.cc DEPS enforce)
cc_library(tensor_base SRCS tensor_base.cc allocator.cc storage.cc DEPS enforce)
cc_library(tensor_meta SRCS tensor_meta.cc DEPS enforce mixed_vector)
cc_library(dense_tensor SRCS dense_tensor.cc DEPS convert_utils tensor_meta tensor_base)
cc_library(pten_device_context SRCS device_context.cc DEPS tensor_base)
# Will be removed once MKLDNN_Tensor is implemented
if(WITH_MKLDNN)
add_dependencies(dense_tensor mkldnn)
add_dependencies(tensor_base mkldnn)
endif()
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/core/device_context.h"
namespace pten {
struct DeviceContext::Impl {
Allocator* allocator_{nullptr};
Impl() = default;
~Impl() = default;
void SetAllocator(Allocator* allocator) { allocator_ = allocator; }
const Allocator& GetAllocator() const { return *allocator_; }
  // TODO(Wilber): Add implementation. TensorBase does not yet expose an
  // interface for communicating with the allocator.
void Alloc(TensorBase* tensor) {}
};
DeviceContext::DeviceContext() { impl_ = std::make_unique<Impl>(); }
DeviceContext::DeviceContext(const DeviceContext& other) {
  impl_ = std::make_unique<Impl>();
  impl_->SetAllocator(const_cast<Allocator*>(&other.GetAllocator()));
}
DeviceContext::DeviceContext(DeviceContext&& other) {
impl_ = std::move(other.impl_);
}
DeviceContext::~DeviceContext() = default;
void DeviceContext::SetAllocator(Allocator* allocator) {
impl_->SetAllocator(allocator);
}
const Allocator& DeviceContext::GetAllocator() const {
return impl_->GetAllocator();
}
void DeviceContext::Alloc(TensorBase* tensor) { impl_->Alloc(tensor); }
} // namespace pten
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
// TODO(wilber): Do we need to use place in pten kernel?
#include "paddle/pten/common/place.h"
#include "paddle/pten/core/candidate/allocator.h"
namespace pten {
class TensorBase;
/**
* DeviceContext provides device-related interfaces.
*
* All kernels must access the interfaces provided by the backend through
* DeviceContext.
*/
class DeviceContext {
public:
/**
   * @brief Default constructor.
*/
DeviceContext();
/**
   * @brief Copy constructor.
*/
DeviceContext(const DeviceContext&);
/**
   * @brief Move constructor.
*/
DeviceContext(DeviceContext&&);
/**
   * @brief Destructor.
*/
virtual ~DeviceContext();
/**
   * @brief Set the device-related Allocator object.
*
* @param allocator
*/
void SetAllocator(Allocator*);
/**
* @brief Get the const Allocator object.
*
* @return Allocator
*/
const Allocator& GetAllocator() const;
/**
   * @brief Allocate memory for a tensor.
*/
void Alloc(pten::TensorBase*);
  // TODO(wilber): Kept only to ease code migration; it will be modified or
  // removed later.
virtual Place GetPlace() const = 0;
  // TODO(wilber): The fluid framework calls Wait() in many places; figure out
  // how to remove this interface.
virtual void Wait() const {}
private:
struct Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace pten
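A hedged sketch (not part of this patch) of how a backend is expected to build on this interface: derive from DeviceContext, implement GetPlace(), and let kernels reach the backend allocator only through the base class. FakeCPUContext and FakeKernel are illustrative names; Place and CPUPlace come from paddle/pten/common/place.h.
#include "paddle/pten/common/place.h"
#include "paddle/pten/core/device_context.h"

namespace pten {

// Illustrative only; real backend contexts (e.g. CPUContext) live under
// paddle/pten/backends/.
class FakeCPUContext : public DeviceContext {
 public:
  Place GetPlace() const override { return CPUPlace(); }
};

// A kernel is expected to see only the DeviceContext interface.
void FakeKernel(const DeviceContext& dev_ctx) {
  // Assumes an Allocator was injected earlier via SetAllocator().
  const Allocator& allocator = dev_ctx.GetAllocator();
  (void)allocator;
  // DeviceContext::Alloc(TensorBase*) is still a stub in this patch
  // (see device_context.cc), so no tensor allocation is shown here.
}

}  // namespace pten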
......@@ -3,3 +3,4 @@ cc_test(test_intrusive_ptr SRCS test_intrusive_ptr.cc)
cc_test(test_type_info SRCS test_type_info.cc)
cc_test(test_convert_utils SRCS test_convert_utils.cc DEPS convert_utils)
cc_test(test_kernel_factory SRCS test_kernel_factory.cc DEPS kernel_factory scale_kernel)
cc_test(test_pten_device_context SRCS test_device_context.cc DEPS pten_context cpu_context)
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gtest/gtest.h"
// TODO(wilber): will be removed after the cpu and gpu contexts are merged.
#include "paddle/pten/backends/cpu/cpu_context.h"
// #include "paddle/pten/backends/all_context.h"
// NOTE: The paddle framework should add a WITH_EIGEN option to support
// compiling without Eigen.
#include "unsupported/Eigen/CXX11/Tensor"
namespace pten {
namespace tests {
TEST(DeviceContext, cpu_context) {
std::cout << "test training scenarios" << std::endl;
{
pten::CPUContext ctx;
CHECK(ctx.eigen_device() != nullptr);
}
std::cout << "test inference scenarios" << std::endl;
Eigen::DefaultDevice* device = new Eigen::DefaultDevice();
{
pten::CPUContextResource ctx_res{device};
pten::CPUContext ctx(ctx_res);
CHECK(ctx.eigen_device() != nullptr);
}
{
pten::CPUContextResource ctx_res{nullptr};
pten::CPUContext ctx(ctx_res);
ctx.SetEigenDevice(device);
CHECK(ctx.eigen_device() != nullptr);
}
delete device;
std::cout << "test copy constructor" << std::endl;
{
pten::CPUContext ctx1;
pten::CPUContext ctx2(ctx1);
CHECK_EQ(ctx1.eigen_device(), ctx2.eigen_device());
}
std::cout << "test move constructor" << std::endl;
{
pten::CPUContext ctx1 = pten::CPUContext();
auto* eigen_device1 = ctx1.eigen_device();
pten::CPUContext ctx2(std::move(ctx1));
auto* eigen_device2 = ctx2.eigen_device();
CHECK_EQ(eigen_device1, eigen_device2);
}
}
} // namespace tests
} // namespace pten
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/kernels/cast_kernel.h"
#include "paddle/pten/api/lib/utils/allocator.h"
......@@ -44,16 +45,11 @@ TEST(DEV_API, cast) {
dense_x_data[i] = i * 1.0;
sum += i * 1.0;
}
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
pten::CPUContext dev_ctx;
pten::DataType out_dtype = pten::DataType::FLOAT64;
// 2. test API
auto out = pten::Cast<float>(
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
dense_x,
out_dtype);
auto out = pten::Cast<float>(dev_ctx, dense_x, out_dtype);
// 3. check result
ASSERT_EQ(out.dims().size(), 2);
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/kernels/complex_kernel.h"
#include "paddle/pten/api/lib/utils/allocator.h"
......@@ -41,13 +42,10 @@ TEST(DEV_API, conj) {
dense_x_data[i] = paddle::complex64(i * 1.0, i * 1.0);
}
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
pten::CPUContext dev_ctx;
// 2. test API
auto out = pten::Conj<paddle::complex64>(
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)), dense_x);
auto out = pten::Conj<paddle::complex64>(dev_ctx, dense_x);
// 3. check result
ASSERT_EQ(out.dims().size(), 2);
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/copy_kernel.h"
......@@ -54,9 +55,8 @@ TEST(DEV_API, copy) {
const auto& a = paddle::platform::CPUPlace();
std::cout << typeid(a).name() << std::endl;
// 2. test API
auto& pool = paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.GetByPlace(paddle::platform::CPUPlace());
pten::Copy(*dev_ctx, *(dense_src.get()), false, dense_dst.get());
pten::CPUContext dev_ctx;
pten::Copy(dev_ctx, *(dense_src.get()), false, dense_dst.get());
// 3. check result
for (int64_t i = 0; i < dense_src->numel(); i++) {
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/kernels/empty_kernel.h"
#include "paddle/pten/kernels/full_kernel.h"
......@@ -30,15 +31,10 @@ using DDim = paddle::framework::DDim;
TEST(DEV_API, empty) {
// 1. create input
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
pten::CPUContext dev_ctx;
// 2. test API
auto out = pten::Empty<float>(
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
{3, 2},
pten::DataType::INT32);
auto out = pten::Empty<float>(dev_ctx, {3, 2}, pten::DataType::INT32);
// 3. check result
ASSERT_EQ(out.dims().size(), 2);
......@@ -59,13 +55,9 @@ TEST(DEV_API, empty_like) {
auto* dense_x_data = dense_x.mutable_data<float>();
dense_x_data[0] = 0;
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
// 2. test API
auto out = pten::EmptyLike<float>(
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)), dense_x);
pten::CPUContext dev_ctx;
auto out = pten::EmptyLike<float>(dev_ctx, dense_x);
// 3. check result
ASSERT_EQ(out.dims().size(), 2);
......@@ -79,16 +71,9 @@ TEST(DEV_API, full) {
// 1. create input
float val = 1.0;
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
// 2. test API
auto out = pten::Full<float>(
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
{3, 2},
val,
pten::DataType::FLOAT32);
pten::CPUContext dev_ctx;
auto out = pten::Full<float>(dev_ctx, {3, 2}, val, pten::DataType::FLOAT32);
// 3. check result
ASSERT_EQ(out.dims().size(), 2);
......@@ -115,15 +100,10 @@ TEST(DEV_API, full_like) {
dense_x_data[0] = 0;
float val = 1.0;
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
pten::CPUContext dev_ctx;
// 2. test API
auto out = pten::FullLike<float>(
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
dense_x,
val);
auto out = pten::FullLike<float>(dev_ctx, dense_x, val);
// 3. check result
ASSERT_EQ(out.dims().size(), 2);
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/kernels/dot_kernel.h"
#include "paddle/pten/api/lib/utils/allocator.h"
......@@ -52,15 +53,9 @@ TEST(DEV_API, dot) {
}
}
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
// 2. test API
auto out = pten::Dot<float>(
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
dense_x,
dense_y);
pten::CPUContext dev_ctx;
auto out = pten::Dot<float>(dev_ctx, dense_x, dense_y);
// 3. check result
ASSERT_EQ(out.dims().size(), 2);
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/kernels/math_kernel.h"
#include "paddle/pten/api/lib/utils/allocator.h"
......@@ -54,16 +55,10 @@ TEST(DEV_API, add) {
dense_y_data[i] = i * 2.0;
}
int axis = 1;
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
// 2. test API
auto dense_out = pten::Add<float>(
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
dense_x,
dense_y,
axis);
pten::CPUContext dev_ctx;
auto dense_out = pten::Add<float>(dev_ctx, dense_x, dense_y, axis);
// 3. check result
ASSERT_EQ(dense_out.dims().size(), 2);
......@@ -107,16 +102,10 @@ TEST(DEV_API, subtract) {
dense_y_data[i] = i * 2.0;
}
int axis = 1;
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
// 2. test API
auto dense_out = pten::Subtract<float>(
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
dense_x,
dense_y,
axis);
pten::CPUContext dev_ctx;
auto dense_out = pten::Subtract<float>(dev_ctx, dense_x, dense_y, axis);
// 3. check result
ASSERT_EQ(dense_out.dims().size(), 2);
......@@ -160,16 +149,10 @@ TEST(DEV_API, divide) {
dense_y_data[i] = i * 2.0 + 1;
}
int axis = 1;
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
// 2. test API
auto dense_out = pten::Divide<float>(
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
dense_x,
dense_y,
axis);
pten::CPUContext dev_ctx;
auto dense_out = pten::Divide<float>(dev_ctx, dense_x, dense_y, axis);
// 3. check result
ASSERT_EQ(dense_out.dims().size(), 2);
......@@ -213,16 +196,10 @@ TEST(DEV_API, multiply) {
dense_y_data[i] = i * 2.0;
}
int axis = 1;
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
// 2. test API
auto dense_out = pten::Multiply<float>(
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
dense_x,
dense_y,
axis);
pten::CPUContext dev_ctx;
auto dense_out = pten::Multiply<float>(dev_ctx, dense_x, dense_y, axis);
// 3. check result
ASSERT_EQ(dense_out.dims().size(), 2);
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/kernels/flatten_kernel.h"
#include "paddle/pten/api/lib/utils/allocator.h"
......@@ -52,16 +53,10 @@ TEST(DEV_API, flatten) {
dense_x_data[i] = i;
}
int start_axis = 1, stop_axis = 2;
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
pten::CPUContext dev_ctx;
// 2. test API
auto out = pten::Flatten<float>(
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
dense_x,
start_axis,
stop_axis);
auto out = pten::Flatten<float>(dev_ctx, dense_x, start_axis, stop_axis);
// 3. check result
std::vector<int> expect_shape = {3, 4, 3};
......
......@@ -50,13 +50,9 @@ TEST(DEV_API, dot) {
}
std::vector<float> sum(9, 6.0);
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* ctx = pool.Get(paddle::platform::CPUPlace());
// 2. test API
auto out = Matmul<float, CPUContext>(
*(static_cast<CPUContext*>(ctx)), dense_x, dense_y, false, false);
pten::CPUContext dev_ctx;
auto out = Matmul<float, CPUContext>(dev_ctx, dense_x, dense_y, false, false);
// 3. check result
ASSERT_EQ(out.dims().size(), 2);
......
......@@ -42,17 +42,11 @@ TEST(DEV_API, mean) {
dense_x_data[i] = i * 1.0;
sum += i * 1.0;
}
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
std::vector<int64_t> dims = {0, 1};
// 2. test API
auto out = pten::Mean<float>(
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
dense_x,
dims,
false);
pten::CPUContext dev_ctx;
auto out = pten::Mean<float>(dev_ctx, dense_x, dims, false);
// 3. check result
ASSERT_EQ(out.dims().size(), 1);
......
......@@ -42,16 +42,11 @@ TEST(DEV_API, reshape) {
for (int i = 0; i < dense_x.numel(); i++) {
dense_x_data[i] = i;
}
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
std::vector<int64_t> shape{12, 3};
// 2. test API
auto out = pten::Reshape<float>(
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
dense_x,
shape);
pten::CPUContext dev_ctx;
auto out = pten::Reshape<float>(dev_ctx, dense_x, shape);
// 3. check result
std::vector<int64_t> expect_shape = {12, 3};
ASSERT_EQ(out.dims()[0], expect_shape[0]);
......
......@@ -44,17 +44,10 @@ TEST(DEV_API, scale) {
float bias = 1;
bool bias_after_scale = true;
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
// 2. test API
auto out = pten::Scale<float>(
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
dense_x,
scale,
bias,
bias_after_scale);
pten::CPUContext dev_ctx;
auto out =
pten::Scale<float>(dev_ctx, dense_x, scale, bias, bias_after_scale);
// 3. check result
ASSERT_EQ(out.dims().size(), 2);
......@@ -88,17 +81,10 @@ TEST(DEV_API, scale_host) {
float bias = 1;
bool bias_after_scale = true;
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
// 2. test API
auto out = pten::Scale<float>(
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
dense_x,
scale,
bias,
bias_after_scale);
pten::CPUContext dev_ctx;
auto out =
pten::Scale<float>(dev_ctx, dense_x, scale, bias, bias_after_scale);
// 3. check result
ASSERT_EQ(out.dims().size(), 2);
......
......@@ -42,18 +42,12 @@ TEST(DEV_API, sum) {
dense_x_data[i] = i * 1.0;
sum += i * 1.0;
}
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
std::vector<int64_t> axis = {0, 1};
pten::CPUContext dev_ctx;
// 2. test API
auto out = pten::Sum<float>(
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
dense_x,
axis,
pten::DataType::FLOAT32,
false);
auto out =
pten::Sum<float>(dev_ctx, dense_x, axis, pten::DataType::FLOAT32, false);
// 3. check result
ASSERT_EQ(out.dims().size(), 1);
......