未验证 提交 064bc4b8 编写于 作者: W Wilber 提交者: GitHub

[PTEN] Add cpu context (#38979)

* add cpu_context.

* update

* update

* update

* update

* update

* fix ci problem

* fix npu ci problem

* update

* fix ci compile
上级 08793179
...@@ -42,9 +42,6 @@ namespace framework { ...@@ -42,9 +42,6 @@ namespace framework {
class Scope; class Scope;
class Variable; class Variable;
} // namespace framework } // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle } // namespace paddle
namespace paddle { namespace paddle {
......
...@@ -37,9 +37,6 @@ namespace paddle { ...@@ -37,9 +37,6 @@ namespace paddle {
namespace framework { namespace framework {
class Scope; class Scope;
} // namespace framework } // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle } // namespace paddle
namespace paddle { namespace paddle {
......
...@@ -48,9 +48,6 @@ class Executor; ...@@ -48,9 +48,6 @@ class Executor;
class ProgramDesc; class ProgramDesc;
class Scope; class Scope;
} // namespace framework } // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle } // namespace paddle
DECLARE_double(eager_delete_tensor_gb); DECLARE_double(eager_delete_tensor_gb);
......
...@@ -48,9 +48,6 @@ class Executor; ...@@ -48,9 +48,6 @@ class Executor;
class ProgramDesc; class ProgramDesc;
class Scope; class Scope;
} // namespace framework } // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle } // namespace paddle
namespace paddle { namespace paddle {
......
...@@ -33,31 +33,39 @@ static void ScaleDeviceDispatch(const pten::DenseTensor& dense_tensor, ...@@ -33,31 +33,39 @@ static void ScaleDeviceDispatch(const pten::DenseTensor& dense_tensor,
pten::DenseTensor* dense_out) { pten::DenseTensor* dense_out) {
switch (dense_tensor.dtype()) { switch (dense_tensor.dtype()) {
case pten::DataType::FLOAT64: { case pten::DataType::FLOAT64: {
pten::ScaleKernel<double, DeviceContext>( pten::ScaleKernel<double, typename paddle::framework::
dev_ctx, dense_tensor /* tensor */, scale /* scale */, ConvertToPtenContext<DeviceContext>::TYPE>(
bias /* bias */, bias_after_scale /* bias_after_scale */, static_cast<const typename paddle::framework::ConvertToPtenContext<
dense_out /* out tensor */); DeviceContext>::TYPE&>(dev_ctx),
dense_tensor /* tensor */, scale /* scale */, bias /* bias */,
bias_after_scale /* bias_after_scale */, dense_out /* out tensor */);
break; break;
} }
case pten::DataType::FLOAT32: { case pten::DataType::FLOAT32: {
pten::ScaleKernel<float, DeviceContext>( pten::ScaleKernel<float, typename paddle::framework::ConvertToPtenContext<
dev_ctx, dense_tensor /* tensor */, scale /* scale */, DeviceContext>::TYPE>(
bias /* bias */, bias_after_scale /* bias_after_scale */, static_cast<const typename paddle::framework::ConvertToPtenContext<
dense_out /* out tensor */); DeviceContext>::TYPE&>(dev_ctx),
dense_tensor /* tensor */, scale /* scale */, bias /* bias */,
bias_after_scale /* bias_after_scale */, dense_out /* out tensor */);
break; break;
} }
case pten::DataType::INT64: { case pten::DataType::INT64: {
pten::ScaleKernel<int64_t, DeviceContext>( pten::ScaleKernel<int64_t, typename paddle::framework::
dev_ctx, dense_tensor /* tensor */, scale /* scale */, ConvertToPtenContext<DeviceContext>::TYPE>(
bias /* bias */, bias_after_scale /* bias_after_scale */, static_cast<const typename paddle::framework::ConvertToPtenContext<
dense_out /* out tensor */); DeviceContext>::TYPE&>(dev_ctx),
dense_tensor /* tensor */, scale /* scale */, bias /* bias */,
bias_after_scale /* bias_after_scale */, dense_out /* out tensor */);
break; break;
} }
case pten::DataType::INT32: { case pten::DataType::INT32: {
pten::ScaleKernel<int32_t, DeviceContext>( pten::ScaleKernel<int32_t, typename paddle::framework::
dev_ctx, dense_tensor /* tensor */, scale /* scale */, ConvertToPtenContext<DeviceContext>::TYPE>(
bias /* bias */, bias_after_scale /* bias_after_scale */, static_cast<const typename paddle::framework::ConvertToPtenContext<
dense_out /* out tensor */); DeviceContext>::TYPE&>(dev_ctx),
dense_tensor /* tensor */, scale /* scale */, bias /* bias */,
bias_after_scale /* bias_after_scale */, dense_out /* out tensor */);
break; break;
} }
default: { default: {
......
...@@ -31,9 +31,6 @@ namespace paddle { ...@@ -31,9 +31,6 @@ namespace paddle {
namespace framework { namespace framework {
class Variable; class Variable;
} // namespace framework } // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle } // namespace paddle
namespace pten { namespace pten {
......
...@@ -18,12 +18,6 @@ ...@@ -18,12 +18,6 @@
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
......
...@@ -33,10 +33,6 @@ namespace ir { ...@@ -33,10 +33,6 @@ namespace ir {
class Node; class Node;
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle } // namespace paddle
namespace paddle { namespace paddle {
......
...@@ -28,9 +28,6 @@ namespace ir { ...@@ -28,9 +28,6 @@ namespace ir {
class Node; class Node;
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle } // namespace paddle
namespace paddle { namespace paddle {
......
...@@ -25,12 +25,6 @@ ...@@ -25,12 +25,6 @@
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -28,9 +28,6 @@ namespace ir { ...@@ -28,9 +28,6 @@ namespace ir {
class Node; class Node;
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle } // namespace paddle
namespace paddle { namespace paddle {
......
...@@ -46,9 +46,6 @@ namespace framework { ...@@ -46,9 +46,6 @@ namespace framework {
class ProgramDesc; class ProgramDesc;
class Scope; class Scope;
} // namespace framework } // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle } // namespace paddle
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
......
...@@ -26,12 +26,6 @@ ...@@ -26,12 +26,6 @@
#include "paddle/fluid/platform/device/mlu/device_context.h" #include "paddle/fluid/platform/device/mlu/device_context.h"
#endif #endif
namespace paddle {
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -18,12 +18,6 @@ limitations under the License. */ ...@@ -18,12 +18,6 @@ limitations under the License. */
#include "paddle/fluid/framework/version.h" #include "paddle/fluid/framework/version.h"
namespace paddle {
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -27,12 +27,6 @@ limitations under the License. */ ...@@ -27,12 +27,6 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
namespace paddle {
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -75,5 +75,16 @@ class KernelArgsNameMaker { ...@@ -75,5 +75,16 @@ class KernelArgsNameMaker {
void SetAllocationForOutputTenosr(pten::DenseTensor* tensor, void SetAllocationForOutputTenosr(pten::DenseTensor* tensor,
const platform::Place& place); const platform::Place& place);
// Type trait that names the context type to use when a fluid operator calls a
// pten kernel. Call sites cast their `dev_ctx` to
// `const ConvertToPtenContext<DeviceContext>::TYPE&` before the call.
// The primary template is the identity mapping: a context type with no
// dedicated specialization is passed through unchanged.
// TODO(Wilber): support other device contexts.
template <typename T>
struct ConvertToPtenContext {
using TYPE = T;
};
// Specialization for the fluid CPU device context: pten kernels are called
// with pten::CPUContext in place of platform::CPUDeviceContext.
// NOTE(review): call sites static_cast a CPUDeviceContext reference to
// `const TYPE&`, which presumably relies on CPUDeviceContext deriving from
// pten::CPUContext -- confirm against the class definition.
template <>
struct ConvertToPtenContext<platform::CPUDeviceContext> {
using TYPE = pten::CPUContext;
};
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -14,12 +14,6 @@ limitations under the License. */ ...@@ -14,12 +14,6 @@ limitations under the License. */
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
namespace paddle {
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -16,17 +16,13 @@ ...@@ -16,17 +16,13 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
namespace paddle { namespace paddle {
namespace platform {
class DeviceContext;
} // namespace platform
namespace framework { namespace framework {
class Variable; class Variable;
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
namespace paddle { namespace paddle {
......
...@@ -173,6 +173,11 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -173,6 +173,11 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
<< " | kernel key: " << pt_kernel_key << " | kernel key: " << pt_kernel_key
<< " | kernel: " << pt_kernel; << " | kernel: " << pt_kernel;
if (platform::is_cpu_place(expected_kernel_key.place_)) {
auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace());
return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature,
pt_kernel, cpu_ctx);
}
// TODO(chenweihang): using CPUKernel when miss device kernel case // TODO(chenweihang): using CPUKernel when miss device kernel case
return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature, return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature,
pt_kernel, dev_ctx); pt_kernel, dev_ctx);
......
...@@ -37,9 +37,6 @@ namespace paddle { ...@@ -37,9 +37,6 @@ namespace paddle {
namespace framework { namespace framework {
class Variable; class Variable;
} // namespace framework } // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle } // namespace paddle
namespace paddle { namespace paddle {
......
...@@ -33,11 +33,6 @@ ...@@ -33,11 +33,6 @@
#include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/platform/for_range.h"
namespace paddle { namespace paddle {
namespace platform {
class DeviceContext;
} // namespace platform
namespace imperative { namespace imperative {
class ParallelContext; class ParallelContext;
class VarBase; class VarBase;
......
...@@ -19,13 +19,9 @@ limitations under the License. */ ...@@ -19,13 +19,9 @@ limitations under the License. */
#include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/stream/stream.h" #include "paddle/fluid/platform/stream/stream.h"
#include "paddle/pten/core/device_context.h"
namespace paddle { namespace paddle {
namespace platform {
class DeviceContext;
} // platform
namespace memory { namespace memory {
using pten::Allocation; using pten::Allocation;
...@@ -37,7 +33,7 @@ extern std::shared_ptr<Allocation> AllocShared(const platform::Place& place, ...@@ -37,7 +33,7 @@ extern std::shared_ptr<Allocation> AllocShared(const platform::Place& place,
extern AllocationPtr Alloc(const platform::Place& place, size_t size); extern AllocationPtr Alloc(const platform::Place& place, size_t size);
extern AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size); extern AllocationPtr Alloc(const pten::DeviceContext& dev_ctx, size_t size);
extern uint64_t Release(const platform::Place& place); extern uint64_t Release(const platform::Place& place);
......
...@@ -19,12 +19,6 @@ limitations under the License. */ ...@@ -19,12 +19,6 @@ limitations under the License. */
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace pten { namespace pten {
class DenseTensor; class DenseTensor;
} // namespace pten } // namespace pten
......
...@@ -67,7 +67,10 @@ class CastOpKernel : public framework::OpKernel<InT> { ...@@ -67,7 +67,10 @@ class CastOpKernel : public framework::OpKernel<InT> {
static_cast<framework::proto::VarType::Type>(out_dtype)); static_cast<framework::proto::VarType::Type>(out_dtype));
// call new kernel // call new kernel
pten::CastKernel<InT>(dev_ctx, *in, pt_out_dtype, out); pten::CastKernel<InT>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*in, pt_out_dtype, out);
} }
}; };
......
...@@ -202,7 +202,10 @@ class CholeskySolveGradKernel : public framework::OpKernel<T> { ...@@ -202,7 +202,10 @@ class CholeskySolveGradKernel : public framework::OpKernel<T> {
commonterm_for_range(commonterm_functor); commonterm_for_range(commonterm_functor);
commonterm_conj = helper.Transpose(commonterm_conj); commonterm_conj = helper.Transpose(commonterm_conj);
pten::AddKernel<T>(dev_ctx, commonterm, commonterm_conj, -1, &commonterm); pten::AddKernel<T>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE &>(dev_ctx),
commonterm, commonterm_conj, -1, &commonterm);
auto mat_dim_u = math::CreateMatrixDescriptor(u_bst.dims(), 0, false); auto mat_dim_u = math::CreateMatrixDescriptor(u_bst.dims(), 0, false);
auto mat_dim_c = auto mat_dim_c =
......
...@@ -36,7 +36,10 @@ class ConjKernel : public framework::OpKernel<T> { ...@@ -36,7 +36,10 @@ class ConjKernel : public framework::OpKernel<T> {
auto& dev_ctx = context.device_context<DeviceContext>(); auto& dev_ctx = context.device_context<DeviceContext>();
// call new kernel // call new kernel
pten::ConjKernel<T>(dev_ctx, *x, out); pten::ConjKernel<T>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*x, out);
} }
}; };
......
...@@ -41,7 +41,11 @@ class DotKernel : public framework::OpKernel<T> { ...@@ -41,7 +41,11 @@ class DotKernel : public framework::OpKernel<T> {
out->mutable_data<T>(x->place()); out->mutable_data<T>(x->place());
// call new kernel // call new kernel
pten::DotKernel<T, DeviceContext>(dev_ctx, *x, *y, out); pten::DotKernel<T, typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*x, *y, out);
} }
}; };
...@@ -61,8 +65,10 @@ class DotGradKernel : public framework::OpKernel<T> { ...@@ -61,8 +65,10 @@ class DotGradKernel : public framework::OpKernel<T> {
auto& dev_ctx = ctx.device_context<DeviceContext>(); auto& dev_ctx = ctx.device_context<DeviceContext>();
// call new kernel // call new kernel
pten::DotGradKernel<T>(dev_ctx, *tensor_x, *tensor_y, *tensor_dout, pten::DotGradKernel<T>(
tensor_dx, tensor_dy); static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*tensor_x, *tensor_y, *tensor_dout, tensor_dx, tensor_dy);
} }
}; };
......
...@@ -61,7 +61,10 @@ class ElementwiseAddKernel : public framework::OpKernel<T> { ...@@ -61,7 +61,10 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x); auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
auto pt_z = paddle::experimental::MakePtenDenseTensor(*z); auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
pten::AddKernel<T>(dev_ctx, *pt_x.get(), *pt_y.get(), axis, pt_z.get()); pten::AddKernel<T>(
static_cast<const typename framework::ConvertToPtenContext<
DeviceContext>::TYPE &>(dev_ctx),
*pt_x.get(), *pt_y.get(), axis, pt_z.get());
} }
}; };
......
...@@ -51,7 +51,10 @@ class ElementwiseDivKernel : public framework::OpKernel<T> { ...@@ -51,7 +51,10 @@ class ElementwiseDivKernel : public framework::OpKernel<T> {
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x); auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
auto pt_z = paddle::experimental::MakePtenDenseTensor(*z); auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
pten::DivideKernel<T>(dev_ctx, *pt_x.get(), *pt_y.get(), axis, pt_z.get()); pten::DivideKernel<T>(
static_cast<const typename framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*pt_x.get(), *pt_y.get(), axis, pt_z.get());
} }
}; };
......
...@@ -124,8 +124,10 @@ class ElementwiseMulKernel : public framework::OpKernel<T> { ...@@ -124,8 +124,10 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x_lod); auto pt_x = paddle::experimental::MakePtenDenseTensor(*x_lod);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
auto pt_z = paddle::experimental::MakePtenDenseTensor(*z_lod); auto pt_z = paddle::experimental::MakePtenDenseTensor(*z_lod);
pten::MultiplyKernel<T>(dev_ctx, *pt_x.get(), *pt_y.get(), axis, pten::MultiplyKernel<T>(
pt_z.get()); static_cast<const typename framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*pt_x.get(), *pt_y.get(), axis, pt_z.get());
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"X's type[%s] is not supported by elementwise_op. X's type should be " "X's type[%s] is not supported by elementwise_op. X's type should be "
......
...@@ -51,8 +51,10 @@ class ElementwiseSubKernel : public framework::OpKernel<T> { ...@@ -51,8 +51,10 @@ class ElementwiseSubKernel : public framework::OpKernel<T> {
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x); auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
auto pt_z = paddle::experimental::MakePtenDenseTensor(*z); auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
pten::SubtractKernel<T>(dev_ctx, *pt_x.get(), *pt_y.get(), axis, pten::SubtractKernel<T>(
pt_z.get()); static_cast<const typename framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*pt_x.get(), *pt_y.get(), axis, pt_z.get());
} }
}; };
......
...@@ -62,7 +62,10 @@ class FillAnyLikeKernel : public framework::OpKernel<T> { ...@@ -62,7 +62,10 @@ class FillAnyLikeKernel : public framework::OpKernel<T> {
const auto& dev_ctx = context.template device_context<DeviceContext>(); const auto& dev_ctx = context.template device_context<DeviceContext>();
// call new kernel // call new kernel
pten::FullLikeKernel<T>(dev_ctx, value, out); pten::FullLikeKernel<T>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
value, out);
} }
}; };
......
...@@ -57,9 +57,9 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> { ...@@ -57,9 +57,9 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> {
} }
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(ctx.GetPlace());
bool cpu_place = force_cpu || ctx.GetPlace() == platform::CPUPlace(); bool cpu_place = force_cpu || ctx.GetPlace() == platform::CPUPlace();
if (cpu_place) { if (cpu_place) {
auto &dev_ctx = *pool.Get(platform::CPUPlace());
math::SetConstant<platform::CPUDeviceContext, T> functor; math::SetConstant<platform::CPUDeviceContext, T> functor;
out->mutable_data(platform::CPUPlace(), data_type); out->mutable_data(platform::CPUPlace(), data_type);
functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx), functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx),
...@@ -67,6 +67,7 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> { ...@@ -67,6 +67,7 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> {
} }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (!cpu_place) { if (!cpu_place) {
auto &dev_ctx = *pool.Get(ctx.GetPlace());
math::SetConstant<platform::CUDADeviceContext, T> functor; math::SetConstant<platform::CUDADeviceContext, T> functor;
out->mutable_data(ctx.GetPlace(), data_type); out->mutable_data(ctx.GetPlace(), data_type);
functor(reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx), functor(reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx),
......
...@@ -67,9 +67,9 @@ class FillConstantBatchSizeLikeOpNPUKernel : public framework::OpKernel<T> { ...@@ -67,9 +67,9 @@ class FillConstantBatchSizeLikeOpNPUKernel : public framework::OpKernel<T> {
} }
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(ctx.GetPlace());
bool cpu_place = force_cpu || ctx.GetPlace() == platform::CPUPlace(); bool cpu_place = force_cpu || ctx.GetPlace() == platform::CPUPlace();
if (cpu_place) { if (cpu_place) {
auto &dev_ctx = *pool.Get(platform::CPUPlace());
math::SetConstant<platform::CPUDeviceContext, T> functor; math::SetConstant<platform::CPUDeviceContext, T> functor;
out->mutable_data(platform::CPUPlace(), data_type); out->mutable_data(platform::CPUPlace(), data_type);
functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx), functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx),
......
...@@ -102,7 +102,6 @@ class FillConstantKernel : public framework::OpKernel<T> { ...@@ -102,7 +102,6 @@ class FillConstantKernel : public framework::OpKernel<T> {
} }
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(ctx.GetPlace());
int actual_place = place_type; int actual_place = place_type;
if (actual_place == -1) { if (actual_place == -1) {
...@@ -123,12 +122,14 @@ class FillConstantKernel : public framework::OpKernel<T> { ...@@ -123,12 +122,14 @@ class FillConstantKernel : public framework::OpKernel<T> {
: "<T>"); : "<T>");
tensor->mutable_data(platform::CPUPlace(), data_type); tensor->mutable_data(platform::CPUPlace(), data_type);
math::SetConstant<platform::CPUDeviceContext, T> functor; math::SetConstant<platform::CPUDeviceContext, T> functor;
auto &dev_ctx = *pool.Get(platform::CPUPlace());
functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx), functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx),
tensor, static_cast<T>(value)); tensor, static_cast<T>(value));
} else if (actual_place == 1) { } else if (actual_place == 1) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
tensor->mutable_data(ctx.GetPlace(), data_type); tensor->mutable_data(ctx.GetPlace(), data_type);
math::SetConstant<platform::CUDADeviceContext, T> functor; math::SetConstant<platform::CUDADeviceContext, T> functor;
auto &dev_ctx = *pool.Get(ctx.GetPlace());
functor(reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx), functor(reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx),
tensor, static_cast<T>(value)); tensor, static_cast<T>(value));
#else #else
...@@ -138,8 +139,10 @@ class FillConstantKernel : public framework::OpKernel<T> { ...@@ -138,8 +139,10 @@ class FillConstantKernel : public framework::OpKernel<T> {
} else if (actual_place == 2) { } else if (actual_place == 2) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
tensor->mutable_data(platform::CUDAPinnedPlace(), data_type); tensor->mutable_data(platform::CUDAPinnedPlace(), data_type);
math::SetConstant<platform::CPUDeviceContext, T> functor; math::SetConstant<platform::CUDAPinnedDeviceContext, T> functor;
functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx), auto &dev_ctx = *pool.Get(platform::CUDAPinnedPlace());
functor(
reinterpret_cast<const platform::CUDAPinnedDeviceContext &>(dev_ctx),
tensor, static_cast<T>(value)); tensor, static_cast<T>(value));
#else #else
PADDLE_THROW(platform::errors::PreconditionNotMet( PADDLE_THROW(platform::errors::PreconditionNotMet(
...@@ -149,6 +152,7 @@ class FillConstantKernel : public framework::OpKernel<T> { ...@@ -149,6 +152,7 @@ class FillConstantKernel : public framework::OpKernel<T> {
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
tensor->mutable_data(ctx.GetPlace(), data_type); tensor->mutable_data(ctx.GetPlace(), data_type);
math::SetConstant<platform::XPUDeviceContext, T> functor; math::SetConstant<platform::XPUDeviceContext, T> functor;
auto &dev_ctx = *pool.Get(ctx.GetPlace());
functor(reinterpret_cast<const platform::XPUDeviceContext &>(dev_ctx), functor(reinterpret_cast<const platform::XPUDeviceContext &>(dev_ctx),
tensor, static_cast<T>(value)); tensor, static_cast<T>(value));
#else #else
......
...@@ -132,8 +132,11 @@ class FlattenContiguousRangeKernel : public framework::OpKernel<T> { ...@@ -132,8 +132,11 @@ class FlattenContiguousRangeKernel : public framework::OpKernel<T> {
auto &dev_ctx = context.device_context<DeviceContext>(); auto &dev_ctx = context.device_context<DeviceContext>();
// call new kernel // call new kernel
pten::FlattenKernel<T, DeviceContext>(dev_ctx, *in, start_axis, stop_axis, pten::FlattenKernel<T, typename paddle::framework::ConvertToPtenContext<
out); DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE &>(dev_ctx),
*in, start_axis, stop_axis, out);
} }
}; };
...@@ -150,7 +153,11 @@ class FlattenContiguousRangeGradKernel : public framework::OpKernel<T> { ...@@ -150,7 +153,11 @@ class FlattenContiguousRangeGradKernel : public framework::OpKernel<T> {
auto &dev_ctx = ctx.device_context<DeviceContext>(); auto &dev_ctx = ctx.device_context<DeviceContext>();
// call new kernel // call new kernel
pten::FlattenGradKernel<T, DeviceContext>(dev_ctx, *d_out, *xshape, d_x); pten::FlattenGradKernel<T, typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE &>(dev_ctx),
*d_out, *xshape, d_x);
} }
}; };
......
...@@ -31,7 +31,6 @@ namespace paddle { ...@@ -31,7 +31,6 @@ namespace paddle {
namespace platform { namespace platform {
class CPUDeviceContext; class CPUDeviceContext;
class CUDADeviceContext; class CUDADeviceContext;
class DeviceContext;
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
......
...@@ -221,7 +221,11 @@ void Tensor_Add(const DeviceContext& dev_ctx, const framework::Tensor& src1, ...@@ -221,7 +221,11 @@ void Tensor_Add(const DeviceContext& dev_ctx, const framework::Tensor& src1,
out->Resize(src1.dims()); out->Resize(src1.dims());
out->mutable_data<T>(dev_ctx.GetPlace()); out->mutable_data<T>(dev_ctx.GetPlace());
pten::AddKernel<T, DeviceContext>(dev_ctx, src1, src2, -1, out); pten::AddKernel<
T, typename paddle::framework::ConvertToPtenContext<DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
src1, src2, -1, out);
} }
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
...@@ -230,7 +234,11 @@ void Tensor_Sub(const DeviceContext& dev_ctx, const framework::Tensor& src1, ...@@ -230,7 +234,11 @@ void Tensor_Sub(const DeviceContext& dev_ctx, const framework::Tensor& src1,
out->Resize(src1.dims()); out->Resize(src1.dims());
out->mutable_data<T>(dev_ctx.GetPlace()); out->mutable_data<T>(dev_ctx.GetPlace());
pten::SubtractKernel<T, DeviceContext>(dev_ctx, src1, src2, -1, out); pten::SubtractKernel<
T, typename paddle::framework::ConvertToPtenContext<DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
src1, src2, -1, out);
} }
template <typename DeviceContext, typename T, size_t D> template <typename DeviceContext, typename T, size_t D>
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "paddle/pten/backends/cpu/cpu_context.h"
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#include <mkl.h> #include <mkl.h>
#endif #endif
...@@ -819,6 +820,12 @@ T *Blas<platform::CPUDeviceContext>::GEMM_ALLOC(const CBLAS_IDENTIFIER id, ...@@ -819,6 +820,12 @@ T *Blas<platform::CPUDeviceContext>::GEMM_ALLOC(const CBLAS_IDENTIFIER id,
const int K) const { const int K) const {
return CBlas<T>::GEMM_ALLOC(id, M, N, K); return CBlas<T>::GEMM_ALLOC(id, M, N, K);
} }
// Packed-GEMM buffer allocation for the pten CPU context.
// Forwards `id`, `M`, `N` and `K` unchanged to CBlas<T>::GEMM_ALLOC and
// returns the pointer it produces.
template <>
template <typename T>
T *Blas<pten::CPUContext>::GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M,
const int N, const int K) const {
return CBlas<T>::GEMM_ALLOC(id, M, N, K);
}
template <> template <>
template <typename T> template <typename T>
...@@ -829,6 +836,15 @@ void Blas<platform::CPUDeviceContext>::GEMM_PACK(const CBLAS_IDENTIFIER id, ...@@ -829,6 +836,15 @@ void Blas<platform::CPUDeviceContext>::GEMM_PACK(const CBLAS_IDENTIFIER id,
const int ld, T *dst) const { const int ld, T *dst) const {
CBlas<T>::GEMM_PACK(CblasRowMajor, id, trans, M, N, K, alpha, src, ld, dst); CBlas<T>::GEMM_PACK(CblasRowMajor, id, trans, M, N, K, alpha, src, ld, dst);
} }
// Packed-GEMM pack step for the pten CPU context.
// Forwards every argument to CBlas<T>::GEMM_PACK, always selecting
// CblasRowMajor as the storage layout; the packed result is written to `dst`.
template <>
template <typename T>
void Blas<pten::CPUContext>::GEMM_PACK(const CBLAS_IDENTIFIER id,
const CBLAS_TRANSPOSE trans, int M,
int N, int K, const T alpha,
const T *src, const int ld,
T *dst) const {
CBlas<T>::GEMM_PACK(CblasRowMajor, id, trans, M, N, K, alpha, src, ld, dst);
}
template <> template <>
template <typename T> template <typename T>
...@@ -838,12 +854,26 @@ void Blas<platform::CPUDeviceContext>::GEMM_COMPUTE( ...@@ -838,12 +854,26 @@ void Blas<platform::CPUDeviceContext>::GEMM_COMPUTE(
CBlas<T>::GEMM_COMPUTE(CblasRowMajor, transA, transB, M, N, K, A, lda, B, ldb, CBlas<T>::GEMM_COMPUTE(CblasRowMajor, transA, transB, M, N, K, A, lda, B, ldb,
beta, C, ldc); beta, C, ldc);
} }
// Packed-GEMM compute step for the pten CPU context.
// Forwards every argument to CBlas<T>::GEMM_COMPUTE, always selecting
// CblasRowMajor as the storage layout. `transA`/`transB` are plain ints here
// and are passed through as-is.
template <>
template <typename T>
void Blas<pten::CPUContext>::GEMM_COMPUTE(int transA, int transB, int M, int N,
int K, const T *A, const int lda,
const T *B, const int ldb, T beta,
T *C, const int ldc) const {
CBlas<T>::GEMM_COMPUTE(CblasRowMajor, transA, transB, M, N, K, A, lda, B, ldb,
beta, C, ldc);
}
template <> template <>
template <typename T> template <typename T>
void Blas<platform::CPUDeviceContext>::GEMM_FREE(T *data) const { void Blas<platform::CPUDeviceContext>::GEMM_FREE(T *data) const {
CBlas<T>::GEMM_FREE(data); CBlas<T>::GEMM_FREE(data);
} }
// Packed-GEMM buffer release for the pten CPU context.
// Hands `data` to CBlas<T>::GEMM_FREE; presumably `data` is a pointer
// previously obtained from GEMM_ALLOC -- confirm against callers.
template <>
template <typename T>
void Blas<pten::CPUContext>::GEMM_FREE(T *data) const {
CBlas<T>::GEMM_FREE(data);
}
#endif #endif
template <> template <>
...@@ -858,6 +888,18 @@ void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA, ...@@ -858,6 +888,18 @@ void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc); beta, C, ldc);
} }
// pten::CPUContext GEMM overload without explicit leading dimensions.
// lda, ldb and ldc are derived from the transpose flags and M/N/K, then the
// call is forwarded to CBlas<T>::GEMM with the CblasRowMajor layout.
template <>
template <typename T>
void Blas<pten::CPUContext>::GEMM(CBLAS_TRANSPOSE transA,
                                  CBLAS_TRANSPOSE transB, int M, int N, int K,
                                  T alpha, const T *A, const T *B, T beta,
                                  T *C) const {
  const bool a_not_transposed = (transA == CblasNoTrans);
  const bool b_not_transposed = (transB == CblasNoTrans);

  // Leading dimension is K for a non-transposed A (M otherwise) and N for a
  // non-transposed B (K otherwise); C always uses N.
  const int lda = a_not_transposed ? K : M;
  const int ldb = b_not_transposed ? N : K;
  const int ldc = N;

  CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
                 beta, C, ldc);
}
template <> template <>
template <typename T> template <typename T>
...@@ -869,6 +911,15 @@ void Blas<platform::CPUDeviceContext>::GEMM(bool transA, bool transB, int M, ...@@ -869,6 +911,15 @@ void Blas<platform::CPUDeviceContext>::GEMM(bool transA, bool transB, int M,
transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A, transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
lda, B, ldb, beta, C, ldc); lda, B, ldb, beta, C, ldc);
} }
// pten::CPUContext GEMM overload taking boolean transpose flags.
// Each flag is mapped to CblasTrans (true) or CblasNoTrans (false) before the
// call is forwarded to CBlas<T>::GEMM with the CblasRowMajor layout.
template <>
template <typename T>
void Blas<pten::CPUContext>::GEMM(bool transA, bool transB, int M, int N, int K,
                                  T alpha, const T *A, int lda, const T *B,
                                  int ldb, T beta, T *C, int ldc) const {
  const CBLAS_TRANSPOSE op_a = transA ? CblasTrans : CblasNoTrans;
  const CBLAS_TRANSPOSE op_b = transB ? CblasTrans : CblasNoTrans;
  CBlas<T>::GEMM(CblasRowMajor, op_a, op_b, M, N, K, alpha, A, lda, B, ldb,
                 beta, C, ldc);
}
template <> template <>
template <typename T> template <typename T>
...@@ -880,6 +931,15 @@ void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA, ...@@ -880,6 +931,15 @@ void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc); beta, C, ldc);
} }
// pten::CPUContext GEMM overload with caller-supplied leading dimensions.
// All arguments are forwarded unchanged to CBlas<T>::GEMM; only the
// CblasRowMajor layout is added.
template <>
template <typename T>
void Blas<pten::CPUContext>::GEMM(CBLAS_TRANSPOSE transA,
                                  CBLAS_TRANSPOSE transB, int M, int N, int K,
                                  T alpha, const T *A, int lda, const T *B,
                                  int ldb, T beta, T *C, int ldc) const {
  using Backend = CBlas<T>;
  const auto layout = CblasRowMajor;
  Backend::GEMM(layout, transA, transB, M, N, K, alpha, A, lda, B, ldb, beta,
                C, ldc);
}
template <typename DeviceContext> template <typename DeviceContext>
template <typename T> template <typename T>
...@@ -920,12 +980,22 @@ void Blas<platform::CPUDeviceContext>::AXPY(int n, T alpha, const T *x, ...@@ -920,12 +980,22 @@ void Blas<platform::CPUDeviceContext>::AXPY(int n, T alpha, const T *x,
T *y) const { T *y) const {
CBlas<T>::AXPY(n, alpha, x, 1, y, 1); CBlas<T>::AXPY(n, alpha, x, 1, y, 1);
} }
// pten::CPUContext specialization of AXPY.
// Forwards to CBlas<T>::AXPY with an increment of 1 for both x and y.
template <>
template <typename T>
void Blas<pten::CPUContext>::AXPY(int n, T alpha, const T *x, T *y) const {
  const int unit_stride = 1;
  CBlas<T>::AXPY(n, alpha, x, unit_stride, y, unit_stride);
}
template <> template <>
template <typename T> template <typename T>
void Blas<platform::CPUDeviceContext>::VCOPY(int n, const T *x, T *y) const { void Blas<platform::CPUDeviceContext>::VCOPY(int n, const T *x, T *y) const {
CBlas<T>::VCOPY(n, x, 1, y, 1); CBlas<T>::VCOPY(n, x, 1, y, 1);
} }
// pten::CPUContext specialization of VCOPY.
// Forwards to CBlas<T>::VCOPY with an increment of 1 for both x and y.
template <>
template <typename T>
void Blas<pten::CPUContext>::VCOPY(int n, const T *x, T *y) const {
  const int unit_stride = 1;
  CBlas<T>::VCOPY(n, x, unit_stride, y, unit_stride);
}
template <> template <>
template <typename T> template <typename T>
...@@ -942,6 +1012,20 @@ void Blas<platform::CPUDeviceContext>::VADD(int n, const T *x, const T *y, ...@@ -942,6 +1012,20 @@ void Blas<platform::CPUDeviceContext>::VADD(int n, const T *x, const T *y,
} }
#endif #endif
} }
// pten::CPUContext specialization of VADD.
// With MKLML the work is delegated to CBlas<T>::VADD. Otherwise the sum is
// built from VCOPY and AXPY on this object.
template <>
template <typename T>
void Blas<pten::CPUContext>::VADD(int n, const T *x, const T *y, T *z) const {
#ifdef PADDLE_WITH_MKLML
  CBlas<T>::VADD(n, x, y, z);
#else
  // When z already aliases x, only y has to be accumulated into it.
  // Otherwise z is first seeded with y and x is accumulated afterwards.
  const T *addend = y;
  if (x != z) {
    this->template VCOPY<T>(n, y, z);
    addend = x;
  }
  this->template AXPY<T>(n, static_cast<T>(1.), addend, z);
#endif
}
template <> template <>
template <typename T> template <typename T>
...@@ -956,6 +1040,18 @@ void Blas<platform::CPUDeviceContext>::VSUB(int n, const T *x, const T *y, ...@@ -956,6 +1040,18 @@ void Blas<platform::CPUDeviceContext>::VSUB(int n, const T *x, const T *y,
} }
#endif #endif
} }
// pten::CPUContext specialization of VSUB: z[i] = x[i] - y[i] for i in [0, n).
// With MKLML the work is delegated to CBlas<T>::VSUB.
template <>
template <typename T>
void Blas<pten::CPUContext>::VSUB(int n, const T *x, const T *y, T *z) const {
#ifdef PADDLE_WITH_MKLML
  CBlas<T>::VSUB(n, x, y, z);
#else
  // TODO: check whether OpenBLAS offers a vsub routine.
  int pos = 0;
  while (pos < n) {
    z[pos] = x[pos] - y[pos];
    ++pos;
  }
#endif
}
template <> template <>
template <typename T> template <typename T>
...@@ -970,6 +1066,18 @@ void Blas<platform::CPUDeviceContext>::VMUL(int n, const T *x, const T *y, ...@@ -970,6 +1066,18 @@ void Blas<platform::CPUDeviceContext>::VMUL(int n, const T *x, const T *y,
} }
#endif #endif
} }
// pten::CPUContext specialization of VMUL: z[i] = x[i] * y[i] for i in [0, n).
// With MKLML the work is delegated to CBlas<T>::VMUL.
template <>
template <typename T>
void Blas<pten::CPUContext>::VMUL(int n, const T *x, const T *y, T *z) const {
#ifdef PADDLE_WITH_MKLML
  CBlas<T>::VMUL(n, x, y, z);
#else
  // TODO: check whether OpenBLAS offers a vmul routine.
  int pos = 0;
  while (pos < n) {
    z[pos] = x[pos] * y[pos];
    ++pos;
  }
#endif
}
template <> template <>
template <typename T> template <typename T>
...@@ -984,6 +1092,18 @@ void Blas<platform::CPUDeviceContext>::VDIV(int n, const T *x, const T *y, ...@@ -984,6 +1092,18 @@ void Blas<platform::CPUDeviceContext>::VDIV(int n, const T *x, const T *y,
} }
#endif #endif
} }
// pten::CPUContext specialization of VDIV: z[i] = x[i] / y[i] for i in [0, n).
// With MKLML the work is delegated to CBlas<T>::VDIV.
template <>
template <typename T>
void Blas<pten::CPUContext>::VDIV(int n, const T *x, const T *y, T *z) const {
#ifdef PADDLE_WITH_MKLML
  CBlas<T>::VDIV(n, x, y, z);
#else
  // TODO: check whether OpenBLAS offers a vdiv routine.
  int pos = 0;
  while (pos < n) {
    z[pos] = x[pos] / y[pos];
    ++pos;
  }
#endif
}
template <> template <>
template <typename T> template <typename T>
...@@ -997,6 +1117,18 @@ void Blas<platform::CPUDeviceContext>::VEXP(int n, const T *x, T *y) const { ...@@ -997,6 +1117,18 @@ void Blas<platform::CPUDeviceContext>::VEXP(int n, const T *x, T *y) const {
} }
#endif #endif
} }
// pten::CPUContext specialization of VEXP: y[i] = std::exp(x[i]) for
// i in [0, n). With MKLML the work is delegated to CBlas<T>::VEXP.
template <>
template <typename T>
void Blas<pten::CPUContext>::VEXP(int n, const T *x, T *y) const {
#ifdef PADDLE_WITH_MKLML
  CBlas<T>::VEXP(n, x, y);
#else
  // TODO: check whether OpenBLAS offers a vexp routine.
  int pos = 0;
  while (pos < n) {
    y[pos] = std::exp(x[pos]);
    ++pos;
  }
#endif
}
template <> template <>
template <typename T> template <typename T>
...@@ -1009,6 +1141,17 @@ void Blas<platform::CPUDeviceContext>::VSQUARE(int n, const T *x, T *y) const { ...@@ -1009,6 +1141,17 @@ void Blas<platform::CPUDeviceContext>::VSQUARE(int n, const T *x, T *y) const {
} }
#endif #endif
} }
// pten::CPUContext specialization of VSQUARE: y[i] = x[i] * x[i] for
// i in [0, n). With MKLML the work is delegated to CBlas<T>::VSQUARE.
template <>
template <typename T>
void Blas<pten::CPUContext>::VSQUARE(int n, const T *x, T *y) const {
#ifdef PADDLE_WITH_MKLML
  CBlas<T>::VSQUARE(n, x, y);
#else
  int pos = 0;
  while (pos < n) {
    y[pos] = x[pos] * x[pos];
    ++pos;
  }
#endif
}
template <> template <>
template <typename T> template <typename T>
...@@ -1022,6 +1165,17 @@ void Blas<platform::CPUDeviceContext>::VPOW(int n, const T *x, T a, ...@@ -1022,6 +1165,17 @@ void Blas<platform::CPUDeviceContext>::VPOW(int n, const T *x, T a,
} }
#endif #endif
} }
// pten::CPUContext specialization of VPOW: y[i] = std::pow(x[i], a) for
// i in [0, n). With MKLML the work is delegated to CBlas<T>::VPOW.
template <>
template <typename T>
void Blas<pten::CPUContext>::VPOW(int n, const T *x, T a, T *y) const {
#ifdef PADDLE_WITH_MKLML
  CBlas<T>::VPOW(n, x, a, y);
#else
  int pos = 0;
  while (pos < n) {
    y[pos] = std::pow(x[pos], a);
    ++pos;
  }
#endif
}
template <> template <>
template <typename T> template <typename T>
...@@ -1037,6 +1191,20 @@ T Blas<platform::CPUDeviceContext>::DOT(int n, const T *x, const T *y) const { ...@@ -1037,6 +1191,20 @@ T Blas<platform::CPUDeviceContext>::DOT(int n, const T *x, const T *y) const {
return sum; return sum;
#endif #endif
} }
// pten::CPUContext specialization of DOT.
// With MKLML this returns CBlas<T>::DOT with an increment of 1 for both
// vectors; otherwise the products x[i] * y[i] are accumulated in index order.
template <>
template <typename T>
T Blas<pten::CPUContext>::DOT(int n, const T *x, const T *y) const {
#ifdef PADDLE_WITH_MKLML
  return CBlas<T>::DOT(n, x, 1, y, 1);
#else
  // TODO: check whether OpenBLAS offers cblas_dot.
  T accumulated = 0;
  int pos = 0;
  while (pos < n) {
    accumulated += x[pos] * y[pos];
    ++pos;
  }
  return accumulated;
#endif
}
template <> template <>
template <typename T> template <typename T>
...@@ -1050,6 +1218,18 @@ void Blas<platform::CPUDeviceContext>::SCAL(int n, const T a, T *x) const { ...@@ -1050,6 +1218,18 @@ void Blas<platform::CPUDeviceContext>::SCAL(int n, const T a, T *x) const {
} }
#endif #endif
} }
// pten::CPUContext specialization of SCAL: scales x in place, x[i] = a * x[i]
// for i in [0, n). With MKLML the work is delegated to CBlas<T>::SCAL with an
// increment of 1.
template <>
template <typename T>
void Blas<pten::CPUContext>::SCAL(int n, const T a, T *x) const {
#ifdef PADDLE_WITH_MKLML
  CBlas<T>::SCAL(n, a, x, 1);
#else
  // TODO: check whether OpenBLAS offers cblas_scal.
  int pos = 0;
  while (pos < n) {
    x[pos] = a * x[pos];
    ++pos;
  }
#endif
}
template <> template <>
template <typename T> template <typename T>
...@@ -1065,6 +1245,20 @@ T Blas<platform::CPUDeviceContext>::ASUM(int n, T *x, int inc) const { ...@@ -1065,6 +1245,20 @@ T Blas<platform::CPUDeviceContext>::ASUM(int n, T *x, int inc) const {
#endif #endif
return sum; return sum;
} }
// pten::CPUContext specialization of ASUM.
//
// Returns the sum of the absolute values of the n elements
// x[0], x[inc], x[2 * inc], ...
//
// With MKLML the work is delegated to CBlas<T>::ASUM(n, x, inc). The fallback
// loop computes the same quantity by hand: it applies std::abs to each element
// and honours the `inc` stride, so both build configurations agree.
template <>
template <typename T>
T Blas<pten::CPUContext>::ASUM(int n, T *x, int inc) const {
  auto sum = static_cast<T>(0.0);
#ifdef PADDLE_WITH_MKLML
  sum = CBlas<T>::ASUM(n, x, inc);
#else
  // TODO(jczaja): check if openblas does provide cblas_sasum/cblas_dasum
  // The reference BLAS asum returns 0 for a non-positive increment; doing the
  // same here also keeps the index below from going negative.
  if (inc <= 0) {
    return sum;
  }
  for (int c = 0; c < n; ++c) {
    // 64-bit index so that c * inc cannot overflow int.
    sum += std::abs(x[static_cast<int64_t>(c) * inc]);
  }
#endif
  return sum;
}
template <> template <>
template <typename T> template <typename T>
...@@ -1074,6 +1268,13 @@ void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha, ...@@ -1074,6 +1268,13 @@ void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha,
CBLAS_TRANSPOSE transA = !trans_a ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE transA = !trans_a ? CblasNoTrans : CblasTrans;
CBlas<T>::GEMV(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1); CBlas<T>::GEMV(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1);
} }
// pten::CPUContext specialization of GEMV.
// Maps trans_a to CblasTrans/CblasNoTrans and forwards to CBlas<T>::GEMV with
// the CblasRowMajor layout, N as the leading dimension of A, and an increment
// of 1 for both B and C.
template <>
template <typename T>
void Blas<pten::CPUContext>::GEMV(bool trans_a, int M, int N, T alpha,
                                  const T *A, const T *B, T beta, T *C) const {
  const CBLAS_TRANSPOSE op_a = trans_a ? CblasTrans : CblasNoTrans;
  const int unit_stride = 1;
  CBlas<T>::GEMV(CblasRowMajor, op_a, M, N, alpha, A, N, B, unit_stride, beta,
                 C, unit_stride);
}
template <> template <>
template <typename T> template <typename T>
...@@ -1112,6 +1313,45 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM( ...@@ -1112,6 +1313,45 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
} }
#endif #endif
} }
// pten::CPUContext specialization of the strided BatchedGEMM.
//
// Runs batchCount GEMMs. Batch k reads A at offset k * strideA, B at offset
// k * strideB and writes C at offset k * M * N.
//
// A, B and C must be non-null; PADDLE_ENFORCE_NOT_NULL reports an
// InvalidArgument error otherwise.
//
// With MKLML the per-batch pointers are collected into arrays and submitted
// through a single CBlas<T>::GEMM_BATCH call with one group. Without MKLML
// each batch is dispatched through this object's GEMM overload.
template <>
template <typename T>
void Blas<pten::CPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
                                         CBLAS_TRANSPOSE transB, int M, int N,
                                         int K, T alpha, const T *A, const T *B,
                                         T beta, T *C, int batchCount,
                                         int64_t strideA,
                                         int64_t strideB) const {
  PADDLE_ENFORCE_NOT_NULL(
      A, platform::errors::InvalidArgument("Pointer A should not be null."));
  PADDLE_ENFORCE_NOT_NULL(
      B, platform::errors::InvalidArgument("Pointer B should not be null."));
  PADDLE_ENFORCE_NOT_NULL(
      C, platform::errors::InvalidArgument("Pointer C should not be null."));
  // Computed in 64 bits so that k * M * N cannot overflow int for large
  // batches; the A/B offsets are already 64-bit through strideA/strideB.
  const int64_t strideC = static_cast<int64_t>(M) * N;
#ifdef PADDLE_WITH_MKLML
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  int ldc = N;
  auto a_array = std::vector<const T *>(batchCount);
  auto b_array = std::vector<const T *>(batchCount);
  auto c_array = std::vector<T *>(batchCount);
  for (int k = 0; k < batchCount; ++k) {
    a_array[k] = &A[k * strideA];
    b_array[k] = &B[k * strideB];
    c_array[k] = &C[k * strideC];
  }
  CBlas<T>::GEMM_BATCH(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha,
                       a_array.data(), &lda, b_array.data(), &ldb, &beta,
                       c_array.data(), &ldc, 1 /* group_count */, &batchCount);
#else
  for (int k = 0; k < batchCount; ++k) {
    auto *Ak = &A[k * strideA];
    auto *Bk = &B[k * strideB];
    auto *Ck = &C[k * strideC];
    this->template GEMM<T>(transA, transB, M, N, K, alpha, Ak, Bk, beta, Ck);
  }
#endif
}
template <> template <>
template <typename T> template <typename T>
...@@ -1132,6 +1372,27 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM( ...@@ -1132,6 +1372,27 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
} }
#endif #endif
} }
// pten::CPUContext specialization of BatchedGEMM taking arrays of matrix
// pointers (A[k], B[k], C[k] for each batch k).
//
// With MKLML all batches go through one CBlas<T>::GEMM_BATCH call with a
// single group; each leading dimension is clamped to at least 1. Without
// MKLML each batch is dispatched through this object's GEMM overload.
template <>
template <typename T>
void Blas<pten::CPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
                                         CBLAS_TRANSPOSE transB, int M, int N,
                                         int K, T alpha, const T **A,
                                         const T **B, T beta, T **C,
                                         int batchCount) const {
#ifdef PADDLE_WITH_MKLML
  const int a_extent = (transA == CblasNoTrans) ? K : M;
  const int b_extent = (transB == CblasNoTrans) ? N : K;
  const int lda = (std::max)(a_extent, 1);
  const int ldb = (std::max)(b_extent, 1);
  const int ldc = (std::max)(N, 1);
  CBlas<T>::GEMM_BATCH(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha, A,
                       &lda, B, &ldb, &beta, C, &ldc, 1 /* group_count */,
                       &batchCount);
#else
  int batch = 0;
  while (batch < batchCount) {
    this->template GEMM<T>(transA, transB, M, N, K, alpha, A[batch], B[batch],
                           beta, C[batch]);
    ++batch;
  }
#endif
}
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \ #if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \
!defined(PADDLE_WITH_HIP) // @{ Group Blas MKLML: BatchedGEMMWithHead !defined(PADDLE_WITH_HIP) // @{ Group Blas MKLML: BatchedGEMMWithHead
...@@ -1204,6 +1465,75 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMMWithHead( ...@@ -1204,6 +1465,75 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMMWithHead(
} }
} }
} }
// pten::CPUContext specialization of BatchedGEMMWithHead (MKLML-only group).
//
// For each of the head_number heads, collects per-batch sub-matrix pointers
// into A, B and C and issues one CBlas<T>::GEMM_BATCH call with a single group
// of batchCount GEMMs.
//
// split_b_vertical == true:
//   B is split along its width (W2 / head_number columns per head), and C has
//   a leading dimension of W2.
// split_b_vertical == false:
//   requires W1 == H2 (an InvalidArgument error is raised otherwise); the
//   shared dimension is split (W1 / head_number per head), and C has a leading
//   dimension of W2 * head_number.
//
// NOTE(review): head_number is assumed to be positive and to divide the
// dimension being split; nothing here checks that.
template <>
template <typename T>
void Blas<pten::CPUContext>::BatchedGEMMWithHead(
    CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int W1, int H1, int W2,
    int H2, T alpha, const T *A, const T *B, T beta, T *C, int batchCount,
    int64_t strideA, int64_t strideB, int64_t head_number,
    bool split_b_vertical) const {
  int lda = (transA == CblasNoTrans) ? W1 : H1;
  int ldb = (transB == CblasNoTrans) ? W2 : H2;
  auto a_array = std::vector<const T *>(batchCount);
  auto b_array = std::vector<const T *>(batchCount);
  auto c_array = std::vector<T *>(batchCount);

  if (split_b_vertical) {
    int ldc = W2;
    int sub_width = W2 / head_number;

    for (int i = 0; i < head_number; i++) {
      int sub_matA_offset = (transA == CblasNoTrans)
                                ? i * (W1 / head_number)
                                : i * (W1 / head_number) * H1;
      int sub_matB_offset = (transB == CblasNoTrans)
                                ? i * (W2 / head_number)
                                : i * (W2 / head_number) * H2;
      int sub_matC_offset = i * W2 / head_number;
      for (int k = 0; k < batchCount; ++k) {
        a_array[k] = &A[k * strideA] + sub_matA_offset;
        b_array[k] = &B[k * strideB] + sub_matB_offset;
        // 64-bit arithmetic so the per-batch C offset cannot overflow int.
        c_array[k] = &C[static_cast<int64_t>(k) * H1 * W2] + sub_matC_offset;
      }

      CBlas<T>::GEMM_BATCH(CblasRowMajor, &transA, &transB, &H1, &sub_width,
                           &H2, &alpha, a_array.data(), &lda, b_array.data(),
                           &ldb, &beta, c_array.data(), &ldc,
                           1 /* group_count */, &batchCount);
    }

  } else {
    PADDLE_ENFORCE_EQ(
        W1, H2,
        platform::errors::InvalidArgument(
            "The first matrix width should be same as second matrix height, "
            "but received first matrix width %d"
            ", second matrix height %d",
            W1, H2));
    int ldc = W2 * head_number;
    int sub_width = W1 / head_number;

    for (int i = 0; i < head_number; i++) {
      int sub_matA_offset = (transA == CblasNoTrans)
                                ? i * (W1 / head_number)
                                : i * (W1 / head_number) * H1;
      int sub_matB_offset = (transB == CblasNoTrans)
                                ? i * (W1 / head_number) * W2
                                : i * (W1 / head_number);
      int sub_matC_offset = i * W2;
      for (int k = 0; k < batchCount; ++k) {
        a_array[k] = &A[k * strideA] + sub_matA_offset;
        b_array[k] = &B[k * strideB] + sub_matB_offset;
        // 64-bit arithmetic so the per-batch C offset cannot overflow int.
        c_array[k] = &C[static_cast<int64_t>(k) * H1 * head_number * W2] +
                     sub_matC_offset;
      }

      CBlas<T>::GEMM_BATCH(CblasRowMajor, &transA, &transB, &H1, &W2,
                           &sub_width, &alpha, a_array.data(), &lda,
                           b_array.data(), &ldb, &beta, c_array.data(), &ldc,
                           1 /* group_count */, &batchCount);
    }
  }
}
#endif // @} End Group Blas MKLML: BatchedGEMMWithHead #endif // @} End Group Blas MKLML: BatchedGEMMWithHead
template <typename DeviceContext> template <typename DeviceContext>
...@@ -1241,6 +1571,31 @@ void Blas<platform::CPUDeviceContext>::MatMul(const int M, const int N, ...@@ -1241,6 +1571,31 @@ void Blas<platform::CPUDeviceContext>::MatMul(const int M, const int N,
CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K,
static_cast<T>(1), A, K, B, N, static_cast<T>(0), C, N); static_cast<T>(1), A, K, B, N, static_cast<T>(0), C, N);
} }
// pten::CPUContext specialization of MatMul for plain (non-transposed)
// operands with alpha = 1 and beta = 0.
//
// With PADDLE_WITH_LIBXSMM the product goes through CBlas<T>::SMM_GEMM;
// otherwise it goes through CBlas<T>::GEMM with the CblasRowMajor layout.
template <>
template <typename T>
void Blas<pten::CPUContext>::MatMul(const int M, const int N, const int K,
                                    const T *A, const T *B, T *C) const {
  const T one = static_cast<T>(1);
  const T zero = static_cast<T>(0);
#ifdef PADDLE_WITH_LIBXSMM
  // Refer to https://github.com/hfp/libxsmm/blob/master/README.md
  // A custom threshold (constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20) could
  // gate this path, but since the matrix is very small the unit of
  // calculation is already very fast and an
  // if (M * N * K < LIBXSMM_THRESHOLD) check would only be overhead, so xsmm
  // is used directly.
  // Note: SMM uses ColMajor, hence B/N are passed ahead of A/M below.
  const char no_trans = 'N';
  CBlas<T>::SMM_GEMM(&no_trans, &no_trans, &N, &M, &K, &one, B, &N, A, &K,
                     &zero, C, &N);
#else
  CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, one, A, K,
                 B, N, zero, C, N);
#endif
}
template <typename DeviceContext> template <typename DeviceContext>
template <typename T> template <typename T>
...@@ -1443,6 +1798,18 @@ void Blas<platform::CPUDeviceContext>::VMERF(int n, const T *a, T *y, ...@@ -1443,6 +1798,18 @@ void Blas<platform::CPUDeviceContext>::VMERF(int n, const T *a, T *y,
} }
#endif #endif
} }
// pten::CPUContext specialization of VMERF: y[i] = erf(a[i]) for i in [0, n).
// With MKLML the work (including `mode`) is delegated to CBlas<T>::VMERF; the
// fallback loop uses std::erf and does not look at `mode`.
template <>
template <typename T>
void Blas<pten::CPUContext>::VMERF(int n, const T *a, T *y,
                                   int64_t mode) const {
#ifdef PADDLE_WITH_MKLML
  CBlas<T>::VMERF(n, a, y, mode);
#else
  int pos = 0;
  while (pos < n) {
    y[pos] = std::erf(a[pos]);
    ++pos;
  }
#endif
}
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
template <> template <>
...@@ -1455,6 +1822,17 @@ void Blas<platform::CPUDeviceContext>::CSRMM( ...@@ -1455,6 +1822,17 @@ void Blas<platform::CPUDeviceContext>::CSRMM(
CBlas<T>::CSRMM(transa, m, n, k, alpha, matdescra, val, indx, pntrb, pntre, b, CBlas<T>::CSRMM(transa, m, n, k, alpha, matdescra, val, indx, pntrb, pntre, b,
ldb, beta, c, ldc); ldb, beta, c, ldc);
} }
// pten::CPUContext specialization of CSRMM (compiled only with MKLML).
// Every argument is forwarded unchanged to CBlas<T>::CSRMM.
template <>
template <typename T>
void Blas<pten::CPUContext>::CSRMM(const char *transa, const int *m,
                                   const int *n, const int *k, const T *alpha,
                                   const char *matdescra, const T *val,
                                   const int *indx, const int *pntrb,
                                   const int *pntre, const T *b, const int *ldb,
                                   const T *beta, T *c, const int *ldc) const {
  using Backend = CBlas<T>;
  Backend::CSRMM(transa, m, n, k, alpha, matdescra, val, indx, pntrb, pntre, b,
                 ldb, beta, c, ldc);
}
#endif #endif
template <> template <>
...@@ -1467,6 +1845,15 @@ void Blas<platform::CPUDeviceContext>::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo, ...@@ -1467,6 +1845,15 @@ void Blas<platform::CPUDeviceContext>::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo,
CBlas<T>::TRSM(CblasRowMajor, side, uplo, transA, diag, M, N, alpha, A, lda, CBlas<T>::TRSM(CblasRowMajor, side, uplo, transA, diag, M, N, alpha, A, lda,
B, ldb); B, ldb);
} }
// pten::CPUContext specialization of TRSM.
// Forwards every argument to CBlas<T>::TRSM, adding only the CblasRowMajor
// layout.
template <>
template <typename T>
void Blas<pten::CPUContext>::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo,
                                  CBLAS_TRANSPOSE transA, CBLAS_DIAG diag,
                                  int M, int N, T alpha, const T *A, int lda,
                                  T *B, int ldb) const {
  using Backend = CBlas<T>;
  const auto layout = CblasRowMajor;
  Backend::TRSM(layout, side, uplo, transA, diag, M, N, alpha, A, lda, B, ldb);
}
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -44,6 +44,22 @@ template struct SetConstant<platform::CUDADeviceContext, ...@@ -44,6 +44,22 @@ template struct SetConstant<platform::CUDADeviceContext,
template struct SetConstant<platform::CUDADeviceContext, template struct SetConstant<platform::CUDADeviceContext,
platform::complex<double>>; platform::complex<double>>;
// Explicit instantiations of SetConstant for platform::CUDAPinnedDeviceContext,
// grouped by element category.

// Floating-point element types.
template struct SetConstant<platform::CUDAPinnedDeviceContext, float>;
template struct SetConstant<platform::CUDAPinnedDeviceContext, double>;
template struct SetConstant<platform::CUDAPinnedDeviceContext,
                            platform::float16>;
template struct SetConstant<platform::CUDAPinnedDeviceContext,
                            platform::bfloat16>;

// Integral and boolean element types.
template struct SetConstant<platform::CUDAPinnedDeviceContext, bool>;
template struct SetConstant<platform::CUDAPinnedDeviceContext, uint8_t>;
template struct SetConstant<platform::CUDAPinnedDeviceContext, int16_t>;
template struct SetConstant<platform::CUDAPinnedDeviceContext, int>;
template struct SetConstant<platform::CUDAPinnedDeviceContext, int64_t>;

// Complex element types.
template struct SetConstant<platform::CUDAPinnedDeviceContext,
                            platform::complex<float>>;
template struct SetConstant<platform::CUDAPinnedDeviceContext,
                            platform::complex<double>>;
#define DEFINE_GPU_TRANS(RANK) \ #define DEFINE_GPU_TRANS(RANK) \
template struct Transpose<platform::CUDADeviceContext, bool, RANK>; \ template struct Transpose<platform::CUDADeviceContext, bool, RANK>; \
template struct Transpose<platform::CUDADeviceContext, float, RANK>; \ template struct Transpose<platform::CUDADeviceContext, float, RANK>; \
......
...@@ -53,7 +53,10 @@ class MatMulV2Kernel : public framework::OpKernel<T> { ...@@ -53,7 +53,10 @@ class MatMulV2Kernel : public framework::OpKernel<T> {
Out->mutable_data<T>(X->place()); Out->mutable_data<T>(X->place());
// call new kernel // call new kernel
pten::MatmulKernel<T>(dev_ctx, *X, *Y, trans_x, trans_y, Out); pten::MatmulKernel<T>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*X, *Y, trans_x, trans_y, Out);
} }
}; };
...@@ -149,8 +152,10 @@ class MatMulV2GradKernel : public framework::OpKernel<T> { ...@@ -149,8 +152,10 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
auto& dev_ctx = ctx.device_context<DeviceContext>(); auto& dev_ctx = ctx.device_context<DeviceContext>();
// call new kernel // call new kernel
pten::MatmulGradKernel<T>(dev_ctx, *x, *y, *dout, transpose_x, transpose_y, pten::MatmulGradKernel<T>(
dx, dy); static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*x, *y, *dout, transpose_x, transpose_y, dx, dy);
} }
}; };
...@@ -178,8 +183,10 @@ class MatMulV2DoubleGradKernel : public framework::OpKernel<T> { ...@@ -178,8 +183,10 @@ class MatMulV2DoubleGradKernel : public framework::OpKernel<T> {
auto& dev_ctx = context.device_context<DeviceContext>(); auto& dev_ctx = context.device_context<DeviceContext>();
// call new kernel // call new kernel
pten::MatmulDoubleGradKernel<T>(dev_ctx, *x, *y, *dout, *ddx, *ddy, pten::MatmulDoubleGradKernel<T>(
transpose_x, transpose_y, dx, dy, ddout); static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*x, *y, *dout, *ddx, *ddy, transpose_x, transpose_y, dx, dy, ddout);
} }
}; };
...@@ -218,7 +225,9 @@ class MatMulV2TripleGradKernel : public framework::OpKernel<T> { ...@@ -218,7 +225,9 @@ class MatMulV2TripleGradKernel : public framework::OpKernel<T> {
auto& dev_ctx = context.device_context<DeviceContext>(); auto& dev_ctx = context.device_context<DeviceContext>();
// call new kernel // call new kernel
pten::MatmulTripleGradKernel<T>( pten::MatmulTripleGradKernel<T>(
dev_ctx, *x, *y, *dout, *ddx, *ddy, *d_dx, *d_dy, *d_ddout, transpose_x, static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*x, *y, *dout, *ddx, *ddy, *d_dx, *d_dy, *d_ddout, transpose_x,
transpose_y, out_d_x, out_d_y, out_d_dout, out_d_ddx, out_d_ddy); transpose_y, out_d_x, out_d_y, out_d_dout, out_d_ddx, out_d_ddy);
} }
}; };
......
...@@ -16,12 +16,6 @@ limitations under the License. */ ...@@ -16,12 +16,6 @@ limitations under the License. */
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace pten { namespace pten {
class DenseTensor; class DenseTensor;
} // namespace pten } // namespace pten
......
...@@ -17,12 +17,6 @@ limitations under the License. */ ...@@ -17,12 +17,6 @@ limitations under the License. */
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/stream/stream.h" #include "paddle/fluid/platform/stream/stream.h"
namespace paddle {
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace pten { namespace pten {
class DenseTensor; class DenseTensor;
} // namespace pten } // namespace pten
......
...@@ -19,12 +19,6 @@ limitations under the License. */ ...@@ -19,12 +19,6 @@ limitations under the License. */
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace pten { namespace pten {
class DenseTensor; class DenseTensor;
} // namespace pten } // namespace pten
......
...@@ -45,9 +45,6 @@ class Executor; ...@@ -45,9 +45,6 @@ class Executor;
class ProgramDesc; class ProgramDesc;
class Scope; class Scope;
} // namespace framework } // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle } // namespace paddle
namespace paddle { namespace paddle {
......
...@@ -22,12 +22,6 @@ limitations under the License. */ ...@@ -22,12 +22,6 @@ limitations under the License. */
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -258,8 +258,11 @@ class ReduceKernel : public framework::OpKernel<T> { ...@@ -258,8 +258,11 @@ class ReduceKernel : public framework::OpKernel<T> {
std::vector<int64_t> tmp_dims(dims.begin(), dims.end()); std::vector<int64_t> tmp_dims(dims.begin(), dims.end());
// call new kernel // call new kernel
pten::Reduce<DeviceContext, T, Functor>( pten::Reduce<typename framework::ConvertToPtenContext<DeviceContext>::TYPE,
dev_ctx, *pt_x.get(), reduce_all, tmp_dims, keep_dim, T, Functor>(
static_cast<const typename framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*pt_x.get(), reduce_all, tmp_dims, keep_dim,
pten::TransToPtenDataType(cast_out_dtype), pt_out.get()); pten::TransToPtenDataType(cast_out_dtype), pt_out.get());
} }
}; };
......
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
// only can include the headers in paddle/pten/api dirs // only can include the headers in paddle/pten/api dirs
#include "paddle/pten/api/lib/utils/tensor_utils.h" #include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/common/scalar_array.h" #include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/kernels/reshape_grad_kernel.h" #include "paddle/pten/kernels/reshape_grad_kernel.h"
#include "paddle/pten/kernels/reshape_kernel.h" #include "paddle/pten/kernels/reshape_kernel.h"
...@@ -435,7 +436,8 @@ class ReshapeKernel { ...@@ -435,7 +436,8 @@ class ReshapeKernel {
} }
if (platform::is_cpu_place(ctx.GetPlace())) { if (platform::is_cpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>(); auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
pten::ReshapeKernel(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out); pten::ReshapeKernel(static_cast<const pten::CPUContext &>(dev_ctx),
*pt_x.get(), pt_scalar_shape, pt_out);
} }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
...@@ -471,7 +473,8 @@ class ReshapeGradKernel { ...@@ -471,7 +473,8 @@ class ReshapeGradKernel {
if (platform::is_cpu_place(ctx.GetPlace())) { if (platform::is_cpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>(); auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
pten::ReshapeGradKernel(dev_ctx, *pt_d_out.get(), pt_d_x.get()); pten::ReshapeGradKernel(static_cast<const pten::CPUContext &>(dev_ctx),
*pt_d_out.get(), pt_d_x.get());
} }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
...@@ -500,7 +503,9 @@ class ReshapeDoubleGradKernel { ...@@ -500,7 +503,9 @@ class ReshapeDoubleGradKernel {
if (platform::is_cpu_place(ctx.GetPlace())) { if (platform::is_cpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>(); auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
pten::ReshapeDoubleGradKernel(dev_ctx, *pt_dd_x.get(), pt_dd_out.get()); pten::ReshapeDoubleGradKernel(
static_cast<const pten::CPUContext &>(dev_ctx), *pt_dd_x.get(),
pt_dd_out.get());
} }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
......
...@@ -67,7 +67,10 @@ class ScaleKernel : public framework::OpKernel<T> { ...@@ -67,7 +67,10 @@ class ScaleKernel : public framework::OpKernel<T> {
auto& dev_ctx = ctx.device_context<DeviceContext>(); auto& dev_ctx = ctx.device_context<DeviceContext>();
// call new kernel // call new kernel
pten::ScaleKernel<T>(dev_ctx, *in, scale, bias, bias_after_scale, out); pten::ScaleKernel<T>(
static_cast<const typename framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*in, scale, bias, bias_after_scale, out);
} }
}; };
......
...@@ -30,7 +30,7 @@ class GPUSeedKernel : public framework::OpKernel<T> { ...@@ -30,7 +30,7 @@ class GPUSeedKernel : public framework::OpKernel<T> {
if (cpu_place) { if (cpu_place) {
platform::DeviceContextPool &pool = platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance(); platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(context.GetPlace()); auto &dev_ctx = *pool.Get(platform::CPUPlace());
out->mutable_data<T>(platform::CPUPlace()); out->mutable_data<T>(platform::CPUPlace());
math::SetConstant<platform::CPUDeviceContext, T> functor; math::SetConstant<platform::CPUDeviceContext, T> functor;
functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx), functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx),
......
...@@ -35,7 +35,11 @@ class SignKernel : public framework::OpKernel<T> { ...@@ -35,7 +35,11 @@ class SignKernel : public framework::OpKernel<T> {
out->mutable_data<T>(x->place()); out->mutable_data<T>(x->place());
// call new kernel // call new kernel
pten::SignKernel<T, DeviceContext>(dev_ctx, *x, out); pten::SignKernel<T, typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*x, out);
} }
}; };
......
...@@ -21,12 +21,6 @@ ...@@ -21,12 +21,6 @@
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace pten { namespace pten {
class DenseTensor; class DenseTensor;
} // namespace pten } // namespace pten
......
...@@ -123,7 +123,7 @@ cc_library(cudnn_workspace_helper SRCS cudnn_workspace_helper.cc DEPS boost) ...@@ -123,7 +123,7 @@ cc_library(cudnn_workspace_helper SRCS cudnn_workspace_helper.cc DEPS boost)
# avoiding cycle dependencies # avoiding cycle dependencies
cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc xxhash ${STREAM_CALLBACK_DEPS} cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc xxhash ${STREAM_CALLBACK_DEPS}
place pten_place eigen3 stringpiece cpu_helper cpu_info framework_proto ${IPU_CTX_DEPS} ${GPU_CTX_DEPS} ${NPU_CTX_DEPS} ${MKLDNN_CTX_DEPS} place pten_place eigen3 stringpiece cpu_helper cpu_info framework_proto ${IPU_CTX_DEPS} ${GPU_CTX_DEPS} ${NPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}
${dgc_deps} dlpack cudnn_workspace_helper ${XPU_CTX_DEPS} ${MLU_CTX_DEPS}) ${dgc_deps} dlpack cudnn_workspace_helper ${XPU_CTX_DEPS} ${MLU_CTX_DEPS} cpu_context)
cc_library(collective_helper SRCS collective_helper.cc gen_comm_id_helper.cc DEPS framework_proto device_context enforce) cc_library(collective_helper SRCS collective_helper.cc gen_comm_id_helper.cc DEPS framework_proto device_context enforce)
if(WITH_ASCEND_CL) if(WITH_ASCEND_CL)
......
...@@ -21,8 +21,6 @@ struct DefaultDevice; ...@@ -21,8 +21,6 @@ struct DefaultDevice;
struct GpuDevice; struct GpuDevice;
} // namespace Eigen } // namespace Eigen
// class DeviceContext;
namespace paddle { namespace paddle {
namespace platform { namespace platform {
......
...@@ -132,7 +132,7 @@ platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) { ...@@ -132,7 +132,7 @@ platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
return it->second.get().get(); return it->second.get().get();
} }
template <typename DevCtx, typename PlaceType> template <typename DevCtx>
inline void EmplaceDeviceContext( inline void EmplaceDeviceContext(
std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>* std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
map_ptr, map_ptr,
...@@ -158,19 +158,14 @@ DeviceContextPool::DeviceContextPool( ...@@ -158,19 +158,14 @@ DeviceContextPool::DeviceContextPool(
} }
for (auto& p : set) { for (auto& p : set) {
if (platform::is_cpu_place(p)) { if (platform::is_cpu_place(p)) {
platform::CPUPlace place;
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
EmplaceDeviceContext<MKLDNNDeviceContext, CPUPlace>(&device_contexts_, EmplaceDeviceContext<MKLDNNDeviceContext>(&device_contexts_, p);
place);
#else #else
EmplaceDeviceContext<CPUDeviceContext, CPUPlace>(&device_contexts_, EmplaceDeviceContext<CPUDeviceContext>(&device_contexts_, p);
place);
#endif #endif
} else if (platform::is_gpu_place(p)) { } else if (platform::is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
platform::CUDAPlace place(p.GetDeviceId()); EmplaceDeviceContext<CUDADeviceContext>(&device_contexts_, p);
EmplaceDeviceContext<CUDADeviceContext, CUDAPlace>(&device_contexts_,
place);
#else #else
PADDLE_THROW( PADDLE_THROW(
platform::errors::Unimplemented("CUDAPlace is not supported. Please " platform::errors::Unimplemented("CUDAPlace is not supported. Please "
...@@ -178,9 +173,7 @@ DeviceContextPool::DeviceContextPool( ...@@ -178,9 +173,7 @@ DeviceContextPool::DeviceContextPool(
#endif #endif
} else if (platform::is_cuda_pinned_place(p)) { } else if (platform::is_cuda_pinned_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
platform::CUDAPinnedPlace place; EmplaceDeviceContext<CUDAPinnedDeviceContext>(&device_contexts_, p);
EmplaceDeviceContext<CUDAPinnedDeviceContext, CUDAPinnedPlace>(
&device_contexts_, place);
#else #else
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"CUDAPlace is not supported. Please re-compile with WITH_GPU " "CUDAPlace is not supported. Please re-compile with WITH_GPU "
...@@ -188,9 +181,7 @@ DeviceContextPool::DeviceContextPool( ...@@ -188,9 +181,7 @@ DeviceContextPool::DeviceContextPool(
#endif #endif
} else if (platform::is_xpu_place(p)) { } else if (platform::is_xpu_place(p)) {
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
platform::XPUPlace place(p.GetDeviceId()); EmplaceDeviceContext<XPUDeviceContext>(&device_contexts_, p);
EmplaceDeviceContext<XPUDeviceContext, XPUPlace>(&device_contexts_,
place);
#else #else
PADDLE_THROW( PADDLE_THROW(
platform::errors::Unimplemented("XPUPlace is not supported. Please " platform::errors::Unimplemented("XPUPlace is not supported. Please "
...@@ -198,9 +189,7 @@ DeviceContextPool::DeviceContextPool( ...@@ -198,9 +189,7 @@ DeviceContextPool::DeviceContextPool(
#endif #endif
} else if (platform::is_mlu_place(p)) { } else if (platform::is_mlu_place(p)) {
#ifdef PADDLE_WITH_MLU #ifdef PADDLE_WITH_MLU
platform::MLUPlace place(p.GetDeviceId()); EmplaceDeviceContext<MLUDeviceContext>(&device_contexts_, p);
EmplaceDeviceContext<MLUDeviceContext, MLUPlace>(&device_contexts_,
place);
#else #else
PADDLE_THROW( PADDLE_THROW(
platform::errors::Unimplemented("MLUPlace is not supported. Please " platform::errors::Unimplemented("MLUPlace is not supported. Please "
...@@ -208,9 +197,7 @@ DeviceContextPool::DeviceContextPool( ...@@ -208,9 +197,7 @@ DeviceContextPool::DeviceContextPool(
#endif #endif
} else if (platform::is_ipu_place(p)) { } else if (platform::is_ipu_place(p)) {
#ifdef PADDLE_WITH_IPU #ifdef PADDLE_WITH_IPU
platform::IPUPlace place(p.GetDeviceId()); EmplaceDeviceContext<IPUDeviceContext>(&device_contexts_, p);
EmplaceDeviceContext<IPUDeviceContext, IPUPlace>(&device_contexts_,
place);
#else #else
PADDLE_THROW( PADDLE_THROW(
platform::errors::Unimplemented("IPUPlace is not supported. Please " platform::errors::Unimplemented("IPUPlace is not supported. Please "
...@@ -218,9 +205,7 @@ DeviceContextPool::DeviceContextPool( ...@@ -218,9 +205,7 @@ DeviceContextPool::DeviceContextPool(
#endif #endif
} else if (platform::is_npu_place(p)) { } else if (platform::is_npu_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL #ifdef PADDLE_WITH_ASCEND_CL
platform::NPUPlace place(p.GetDeviceId()); EmplaceDeviceContext<NPUDeviceContext>(&device_contexts_, p);
EmplaceDeviceContext<NPUDeviceContext, NPUPlace>(&device_contexts_,
place);
#else #else
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"NPUPlace is not supported. Please " "NPUPlace is not supported. Please "
...@@ -228,9 +213,7 @@ DeviceContextPool::DeviceContextPool( ...@@ -228,9 +213,7 @@ DeviceContextPool::DeviceContextPool(
#endif #endif
} else if (platform::is_npu_pinned_place(p)) { } else if (platform::is_npu_pinned_place(p)) {
#ifdef PADDLE_WITH_ASCEND_CL #ifdef PADDLE_WITH_ASCEND_CL
platform::NPUPinnedPlace place; EmplaceDeviceContext<NPUPinnedDeviceContext>(&device_contexts_, p);
EmplaceDeviceContext<NPUPinnedDeviceContext, NPUPinnedPlace>(
&device_contexts_, place);
#else #else
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"NPUPinnedPlace is not supported. Please re-compile with " "NPUPinnedPlace is not supported. Please re-compile with "
...@@ -241,19 +224,9 @@ DeviceContextPool::DeviceContextPool( ...@@ -241,19 +224,9 @@ DeviceContextPool::DeviceContextPool(
} }
} }
CPUDeviceContext::CPUDeviceContext() { CPUDeviceContext::CPUDeviceContext() : pten::CPUContext() {}
eigen_device_.reset(new Eigen::DefaultDevice());
}
CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) {
eigen_device_.reset(new Eigen::DefaultDevice());
}
Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
return eigen_device_.get();
}
Place CPUDeviceContext::GetPlace() const { return place_; } CPUDeviceContext::CPUDeviceContext(CPUPlace place) : pten::CPUContext() {}
#ifdef PADDLE_WITH_IPU #ifdef PADDLE_WITH_IPU
IPUDeviceContext::IPUDeviceContext(IPUPlace place) : place_(place) { IPUDeviceContext::IPUDeviceContext(IPUPlace place) : place_(place) {
......
...@@ -18,6 +18,9 @@ limitations under the License. */ ...@@ -18,6 +18,9 @@ limitations under the License. */
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/device_context.h"
#include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/memory/malloc.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/device/gpu/gpu_helper.h" #include "paddle/fluid/platform/device/gpu/gpu_helper.h"
...@@ -117,26 +120,15 @@ constexpr DeviceType kNPU = DeviceType::NPU; ...@@ -117,26 +120,15 @@ constexpr DeviceType kNPU = DeviceType::NPU;
constexpr DeviceType kIPU = DeviceType::IPU; constexpr DeviceType kIPU = DeviceType::IPU;
constexpr DeviceType kMLU = DeviceType::MLU; constexpr DeviceType kMLU = DeviceType::MLU;
class DeviceContext { using DeviceContext = pten::DeviceContext;
public:
virtual ~DeviceContext() PADDLE_MAY_THROW {}
virtual Place GetPlace() const = 0;
virtual void Wait() const {}
};
class CPUDeviceContext : public DeviceContext { // using CPUDeviceContext = pten::CPUContext;
// TODO(wilber): The place constructor is used in many places, it is more
// difficult to use CPUDeviceContext = pten::CPUContext directly.
class CPUDeviceContext : public pten::CPUContext {
public: public:
CPUDeviceContext(); CPUDeviceContext();
explicit CPUDeviceContext(CPUPlace place); explicit CPUDeviceContext(CPUPlace place);
Eigen::DefaultDevice* eigen_device() const;
Place GetPlace() const override;
private:
CPUPlace place_;
std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
}; };
template <typename Place> template <typename Place>
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
...@@ -27,6 +28,7 @@ struct ForRange { ...@@ -27,6 +28,7 @@ struct ForRange {
void operator()(Function func) const; void operator()(Function func) const;
}; };
// NOTE: After the pten kernel is migrated, it needs to be deleted.
template <> template <>
struct ForRange<CPUDeviceContext> { struct ForRange<CPUDeviceContext> {
ForRange(const CPUDeviceContext& dev_ctx, size_t limit) : limit_(limit) {} ForRange(const CPUDeviceContext& dev_ctx, size_t limit) : limit_(limit) {}
...@@ -41,6 +43,20 @@ struct ForRange<CPUDeviceContext> { ...@@ -41,6 +43,20 @@ struct ForRange<CPUDeviceContext> {
size_t limit_; size_t limit_;
}; };
// Sequential CPU specialization for pten::CPUContext: invokes the functor
// once for every index in [0, limit), in increasing order.
template <>
struct ForRange<pten::CPUContext> {
  // `dev_ctx` is accepted for interface parity with the other specializations
  // and is not used; only the iteration bound is stored.
  ForRange(const pten::CPUContext& dev_ctx, size_t limit) : limit_(limit) {}

  // Calls func(idx) for idx = 0, 1, ..., limit_ - 1.
  template <typename Function>
  void operator()(Function func) const {
    size_t idx = 0;
    while (idx < limit_) {
      func(idx);
      ++idx;
    }
  }

  size_t limit_;
};
#if defined(__NVCC__) || defined(__HIPCC__) #if defined(__NVCC__) || defined(__HIPCC__)
template <typename Function> template <typename Function>
__global__ static void ForRangeElemwiseOpGridIsOne(Function func) { __global__ static void ForRangeElemwiseOpGridIsOne(Function func) {
......
...@@ -59,6 +59,7 @@ struct Transform { ...@@ -59,6 +59,7 @@ struct Transform {
BinaryOperation op); BinaryOperation op);
}; };
// NOTE: After the pten kernel is migrated, it needs to be deleted.
template <> template <>
struct Transform<platform::CPUDeviceContext> { struct Transform<platform::CPUDeviceContext> {
template <typename InputIter, typename OutputIter, typename UnaryOperation> template <typename InputIter, typename OutputIter, typename UnaryOperation>
...@@ -76,6 +77,23 @@ struct Transform<platform::CPUDeviceContext> { ...@@ -76,6 +77,23 @@ struct Transform<platform::CPUDeviceContext> {
} }
}; };
// Specialization for pten::CPUContext: both overloads delegate directly to
// std::transform. The context argument is not consulted.
template <>
struct Transform<pten::CPUContext> {
  // Unary form: stores op(x) for each x in [first, last) starting at `result`.
  template <typename SrcIter, typename DstIter, typename UnaryFn>
  void operator()(const pten::CPUContext& context, SrcIter first, SrcIter last,
                  DstIter result, UnaryFn op) {
    std::transform(first, last, result, op);
  }

  // Binary form: stores op(x, y) for each x in [first1, last1) paired with the
  // elements starting at `first2`, writing the results starting at `result`.
  template <typename LhsIter, typename RhsIter, typename DstIter,
            typename BinaryFn>
  void operator()(const pten::CPUContext& context, LhsIter first1,
                  LhsIter last1, RhsIter first2, DstIter result,
                  BinaryFn op) {
    std::transform(first1, last1, first2, result, op);
  }
};
#if defined(__NVCC__) || defined(__HIPCC__) #if defined(__NVCC__) || defined(__HIPCC__)
template <> template <>
struct Transform<platform::CUDADeviceContext> { struct Transform<platform::CUDADeviceContext> {
......
add_subdirectory(cpu)
cc_library(pten_context SRCS all_context.cc DEPS device_context) cc_library(pten_context SRCS all_context.cc DEPS device_context)
...@@ -24,7 +24,9 @@ limitations under the License. */ ...@@ -24,7 +24,9 @@ limitations under the License. */
#include "paddle/pten/backends/gpu/gpu_context.h" #include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/backends/xpu/xpu_context.h" #include "paddle/pten/backends/xpu/xpu_context.h"
// TODO(wilber): DeviceContextPool needs to include the fluid header.
#include "paddle/fluid/platform/device_context.h"
namespace pten { namespace pten {
using DeviceContext = paddle::platform::DeviceContext;
using DeviceContextPool = paddle::platform::DeviceContextPool; using DeviceContextPool = paddle::platform::DeviceContextPool;
} // namespace pten } // namespace pten
# Defines the cpu_context library from cpu_context.cc. Both branches depend on
# pten_device_context; the WITH_MKLDNN branch additionally lists mkldnn.
if(WITH_MKLDNN)
# TODO(wilber): support mkldnn context.
cc_library(cpu_context SRCS cpu_context.cc DEPS pten_device_context mkldnn)
else()
cc_library(cpu_context SRCS cpu_context.cc DEPS pten_device_context)
endif()
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/api/ext/exception.h"
// NOTE: The paddle framework should add WITH_EIGEN option to support compile
// without eigen.
#include "paddle/pten/core/device_context.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace pten {
// Private implementation of CPUContext.
//
// Holds the Eigen device used by the context. The device is either owned by
// this object (created by the default constructor) or borrowed from the caller
// (supplied through CPUContextResource or SetEigenDevice). A null
// `res_.device` marks the owning case; a non-null one marks the borrowed case.
struct CPUContext::CPUImpl {
  Eigen::DefaultDevice* device_{nullptr};
  CPUContextResource res_;
  CPUPlace place_;

  // Creates and owns a fresh Eigen device.
  CPUImpl() { device_ = new Eigen::DefaultDevice(); }

  // Users need to manage external resources.
  explicit CPUImpl(const CPUContextResource& ctx_res) : res_(ctx_res) {
    device_ = res_.device;
  }

  // The object may own a raw pointer, so a member-wise copy would end in a
  // double delete. It is only ever held through std::unique_ptr.
  CPUImpl(const CPUImpl&) = delete;
  CPUImpl& operator=(const CPUImpl&) = delete;

  ~CPUImpl() { ReleaseOwnedDevice(); }

  // Returns the current device; fails the PD_CHECK when none has been set.
  Eigen::DefaultDevice* GetEigenDevice() const {
    PD_CHECK(device_ != nullptr, "the eigen_device is nullptr.");
    return device_;
  }

  // Switches to an externally managed device. A null argument is ignored.
  // Passing the device that is already in use is a no-op, so an owned device
  // is never released while it is still being handed out.
  void SetEigenDevice(Eigen::DefaultDevice* device) {
    if (device == nullptr || device == device_) {
      return;
    }
    // Free the device created by the default constructor before dropping the
    // only pointer to it; previously it was leaked here.
    ReleaseOwnedDevice();
    res_.device = device;
    device_ = device;
  }

  Place GetPlace() const { return place_; }

 private:
  // Deletes `device_` only when this object owns it (no external resource
  // has been registered in `res_`).
  void ReleaseOwnedDevice() {
    if (res_.device == nullptr) {
      delete device_;
      device_ = nullptr;
    }
  }
};
// Default construction: the implementation object is created right away in
// the member initializer list.
CPUContext::CPUContext()
    : DeviceContext(), cpu_impl_(std::make_unique<CPUImpl>()) {}
// Copy construction shares the Eigen device of `other` instead of creating a
// new one; the copy does not own the device, so `other` (or whoever owns the
// device) must outlive it.
//
// The implementation is built directly from a resource that points at the
// shared device. Creating a default implementation first and then replacing
// its device would allocate an Eigen device that is never freed.
//
// NOTE(review): the DeviceContext base is default-constructed here, so the
// allocator of `other` is not propagated -- confirm this is intended.
CPUContext::CPUContext(const CPUContext& other)
    : DeviceContext(), cpu_impl_(nullptr) {
  CPUContextResource shared_res;
  shared_res.device = other.eigen_device();
  cpu_impl_ = std::make_unique<CPUImpl>(shared_res);
}
// Move construction transfers both the DeviceContext base state and the CPU
// implementation object from `other`. Default-constructing the base here
// would silently drop whatever allocator had been set on `other`.
//
// Only the base sub-object is consumed by the first initializer, so reading
// `other.cpu_impl_` afterwards is well-defined. `other` is left without an
// implementation object and must not be used except to be destroyed.
CPUContext::CPUContext(CPUContext&& other)
    : DeviceContext(std::move(other)),
      cpu_impl_(std::move(other.cpu_impl_)) {}
// Defaulted in this translation unit, where CPUImpl is a complete type, so
// that the std::unique_ptr<CPUImpl> member can destroy the implementation.
CPUContext::~CPUContext() = default;
// Builds a context on top of caller-managed resources: the implementation
// object is created from `ctx_res` in the member initializer list.
CPUContext::CPUContext(const CPUContextResource& ctx_res)
    : DeviceContext(), cpu_impl_(std::make_unique<CPUImpl>(ctx_res)) {}
// Returns the Eigen device held by the implementation object. The lookup is
// delegated to CPUImpl::GetEigenDevice, which checks that the device is set.
Eigen::DefaultDevice* CPUContext::eigen_device() const {
return cpu_impl_->GetEigenDevice();
}
// Replaces the Eigen device used by this context with a caller-managed one.
// Forwards to CPUImpl::SetEigenDevice, which ignores a null `device`.
void CPUContext::SetEigenDevice(Eigen::DefaultDevice* device) {
cpu_impl_->SetEigenDevice(device);
}
// Returns the place stored in the implementation object (a CPUPlace member).
Place CPUContext::GetPlace() const { return cpu_impl_->GetPlace(); }
} // namespace pten
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -14,9 +14,47 @@ limitations under the License. */ ...@@ -14,9 +14,47 @@ limitations under the License. */
#pragma once #pragma once
// See Note [ Why still include the fluid headers? ] #include <memory>
#include "paddle/fluid/platform/device_context.h"
#include "paddle/pten/backends/cpu/forwards.h"
#include "paddle/pten/core/device_context.h"
// TODO(wilber): Do we need to use place in pten kernel?
#include "paddle/pten/common/place.h"
namespace pten { namespace pten {
using CPUContext = paddle::platform::CPUDeviceContext;
// Bundle of externally managed resources that a CPUContext can be built on.
// A null `device` means that no external Eigen device is supplied.
struct CPUContextResource {
  Eigen::DefaultDevice* device = nullptr;
};
// DeviceContext for CPU. The state lives in a private implementation object
// (CPUImpl) that is defined in the source file.
class CPUContext : public DeviceContext {
 public:
  // NOTE: The context holds its own resources. Used in training scenarios.
  CPUContext();

  // NOTE: Resources are managed by external users. Used in inference
  // scenarios.
  explicit CPUContext(const CPUContextResource& ctx_res);

  // NOTE: The copy shares the same underlying resources as the source; please
  // ensure that those resources are not released while the copy is in use.
  CPUContext(const CPUContext&);

  CPUContext(CPUContext&&);

  ~CPUContext();

  // Returns the Eigen device used by this context.
  Eigen::DefaultDevice* eigen_device() const;

  // Makes the context use an externally managed Eigen device.
  void SetEigenDevice(Eigen::DefaultDevice* device);

  // TODO(wilber): Whether the interface should be preserved.
  Place GetPlace() const override;

 private:
  struct CPUImpl;
  std::unique_ptr<CPUImpl> cpu_impl_;
};
} // namespace pten } // namespace pten
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Forward-declares.
#pragma once
// Forward declaration of Eigen DefaultDevice types.
namespace Eigen {
struct DefaultDevice;
} // namespace Eigen
...@@ -13,8 +13,10 @@ cc_library(arg_map_context SRCS arg_map_context.cc DEPS enforce) ...@@ -13,8 +13,10 @@ cc_library(arg_map_context SRCS arg_map_context.cc DEPS enforce)
cc_library(tensor_base SRCS tensor_base.cc allocator.cc storage.cc DEPS enforce) cc_library(tensor_base SRCS tensor_base.cc allocator.cc storage.cc DEPS enforce)
cc_library(tensor_meta SRCS tensor_meta.cc DEPS enforce mixed_vector) cc_library(tensor_meta SRCS tensor_meta.cc DEPS enforce mixed_vector)
cc_library(dense_tensor SRCS dense_tensor.cc DEPS convert_utils tensor_meta tensor_base) cc_library(dense_tensor SRCS dense_tensor.cc DEPS convert_utils tensor_meta tensor_base)
cc_library(pten_device_context SRCS device_context.cc DEPS tensor_base )
# Will remove once we implemented MKLDNN_Tensor # Will remove once we implemented MKLDNN_Tensor
if(WITH_MKLDNN) if(WITH_MKLDNN)
add_dependencies(dense_tensor mkldnn) add_dependencies(dense_tensor mkldnn)
add_dependencies(tensor_base mkldnn)
endif() endif()
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/core/device_context.h"
namespace pten {
// Private implementation of DeviceContext. It stores a non-owning pointer to
// the allocator registered through SetAllocator.
struct DeviceContext::Impl {
  Impl() = default;
  ~Impl() = default;

  // Remembers `allocator`; the pointer is stored as given.
  void SetAllocator(Allocator* allocator) { allocator_ = allocator; }

  // Dereferences the stored pointer.
  // NOTE(review): `allocator_` starts out as nullptr, so this looks unsafe to
  // call before SetAllocator -- confirm callers always register one first.
  const Allocator& GetAllocator() const { return *allocator_; }

  // TODO(Wilber): Add impl. It seems that tensorbase not have interface to
  // communicate with allocator.
  void Alloc(TensorBase* tensor) {}

  Allocator* allocator_ = nullptr;
};
// Default construction: creates an empty implementation object (no allocator
// registered yet).
DeviceContext::DeviceContext() : impl_(std::make_unique<Impl>()) {}
// Copy construction: the new context gets its own implementation object that
// points at the same allocator as `other` (the allocator itself is shared,
// not copied).
//
// `impl_` must be created before it is used; a default-initialized
// std::unique_ptr is null, and calling through it is undefined behavior. The
// allocator pointer is read directly from the source implementation so that
// an unset (null) allocator is copied as-is instead of being dereferenced,
// and a moved-from `other` (null implementation) yields an empty context.
DeviceContext::DeviceContext(const DeviceContext& other)
    : impl_(std::make_unique<Impl>()) {
  if (other.impl_ != nullptr) {
    impl_->SetAllocator(other.impl_->allocator_);
  }
}
// Move construction: takes over the implementation object of `other`, which
// is left holding a null implementation pointer.
DeviceContext::DeviceContext(DeviceContext&& other)
    : impl_(std::move(other.impl_)) {}
// Defaulted in this translation unit, where Impl is a complete type, so that
// the std::unique_ptr<Impl> member can destroy the implementation.
DeviceContext::~DeviceContext() = default;
// Registers `allocator` with the implementation object. Only the pointer is
// stored; no copy of the allocator is made.
void DeviceContext::SetAllocator(Allocator* allocator) {
impl_->SetAllocator(allocator);
}
// Returns the allocator registered through SetAllocator, as provided by
// Impl::GetAllocator.
// NOTE(review): Impl keeps a pointer that defaults to nullptr and
// dereferences it, so this looks unsafe to call before an allocator has been
// set -- confirm against callers.
const Allocator& DeviceContext::GetAllocator() const {
return impl_->GetAllocator();
}
// Forwards to Impl::Alloc, whose body is currently empty, so this call has no
// effect on `tensor` yet.
void DeviceContext::Alloc(TensorBase* tensor) { impl_->Alloc(tensor); }
} // namespace pten
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
// TODO(wilber): Do we need to use place in pten kernel?
#include "paddle/pten/common/place.h"
#include "paddle/pten/core/candidate/allocator.h"
namespace pten {
class TensorBase;
/**
* DeviceContext provides device-related interfaces.
*
* All kernels must access the interfaces provided by the backend through
* DeviceContext.
*/
class DeviceContext {
 public:
  /**
   * @brief Default constructor.
   */
  DeviceContext();

  /**
   * @brief Copy constructor.
   *
   * @param other The context to copy from.
   */
  DeviceContext(const DeviceContext& other);

  /**
   * @brief Move constructor.
   *
   * @param other The context to move from.
   */
  DeviceContext(DeviceContext&& other);

  /**
   * @brief Virtual destructor.
   */
  virtual ~DeviceContext();

  /**
   * @brief Set the device-related Allocator object.
   *
   * @param allocator The allocator to register with this context.
   */
  void SetAllocator(Allocator* allocator);

  /**
   * @brief Get the registered Allocator object as a const reference.
   *
   * @return The Allocator.
   */
  const Allocator& GetAllocator() const;

  /**
   * @brief Allocate memory for a tensor.
   *
   * @param tensor The tensor to allocate memory for.
   */
  void Alloc(pten::TensorBase* tensor);

  // TODO(wilber): Just for the convenience of migrating the code, it will be
  // modified or removed later.
  virtual Place GetPlace() const = 0;

  // TODO(wilber): The fluid framework uses Wait() in many places; how can
  // this API be removed?
  virtual void Wait() const {}

 private:
  struct Impl;
  std::unique_ptr<Impl> impl_;
};
} // namespace pten
...@@ -3,3 +3,4 @@ cc_test(test_intrusive_ptr SRCS test_intrusive_ptr.cc) ...@@ -3,3 +3,4 @@ cc_test(test_intrusive_ptr SRCS test_intrusive_ptr.cc)
cc_test(test_type_info SRCS test_type_info.cc) cc_test(test_type_info SRCS test_type_info.cc)
cc_test(test_convert_utils SRCS test_convert_utils.cc DEPS convert_utils) cc_test(test_convert_utils SRCS test_convert_utils.cc DEPS convert_utils)
cc_test(test_kernel_factory SRCS test_kernel_factory.cc DEPS kernel_factory scale_kernel) cc_test(test_kernel_factory SRCS test_kernel_factory.cc DEPS kernel_factory scale_kernel)
cc_test(test_pten_device_context SRCS test_device_context.cc DEPS pten_context cpu_context)
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gtest/gtest.h"
// TODO(wilber): will be removed after the cpu and gpu contexts are merged.
#include "paddle/pten/backends/cpu/cpu_context.h"
// #include "paddle/pten/backends/all_context.h"
// NOTE: The paddle framework should add WITH_EIGEN option to support compile
// without eigen.
#include "unsupported/Eigen/CXX11/Tensor"
namespace pten {
namespace tests {
// Covers the ways a pten::CPUContext can obtain its Eigen device: created by
// the context itself, supplied through CPUContextResource, supplied later via
// SetEigenDevice, shared by copy construction and transferred by move
// construction.
//
// gtest assertions are used so that a failure is reported through the test
// framework instead of aborting the process, and the externally managed
// device lives on the stack so it cannot leak on an early exit.
TEST(DeviceContext, cpu_context) {
  std::cout << "test training scenarios" << std::endl;
  {
    pten::CPUContext ctx;
    ASSERT_TRUE(ctx.eigen_device() != nullptr);
  }

  std::cout << "test inference scenarios" << std::endl;
  // Externally managed device; it outlives every context that borrows it.
  Eigen::DefaultDevice device;
  {
    pten::CPUContextResource ctx_res{&device};
    pten::CPUContext ctx(ctx_res);
    ASSERT_TRUE(ctx.eigen_device() != nullptr);
  }
  {
    pten::CPUContextResource ctx_res{nullptr};
    pten::CPUContext ctx(ctx_res);
    ctx.SetEigenDevice(&device);
    ASSERT_TRUE(ctx.eigen_device() != nullptr);
  }

  std::cout << "test copy constructor" << std::endl;
  {
    pten::CPUContext ctx1;
    pten::CPUContext ctx2(ctx1);
    // The copy shares the device of the source.
    ASSERT_EQ(ctx1.eigen_device(), ctx2.eigen_device());
  }

  std::cout << "test move constructor" << std::endl;
  {
    pten::CPUContext ctx1 = pten::CPUContext();
    auto* eigen_device1 = ctx1.eigen_device();
    pten::CPUContext ctx2(std::move(ctx1));
    auto* eigen_device2 = ctx2.eigen_device();
    // The moved-to context takes over the very same device.
    ASSERT_EQ(eigen_device1, eigen_device2);
  }
}
} // namespace tests
} // namespace pten
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/kernels/cast_kernel.h" #include "paddle/pten/kernels/cast_kernel.h"
#include "paddle/pten/api/lib/utils/allocator.h" #include "paddle/pten/api/lib/utils/allocator.h"
...@@ -44,16 +45,11 @@ TEST(DEV_API, cast) { ...@@ -44,16 +45,11 @@ TEST(DEV_API, cast) {
dense_x_data[i] = i * 1.0; dense_x_data[i] = i * 1.0;
sum += i * 1.0; sum += i * 1.0;
} }
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
pten::CPUContext dev_ctx;
pten::DataType out_dtype = pten::DataType::FLOAT64; pten::DataType out_dtype = pten::DataType::FLOAT64;
// 2. test API // 2. test API
auto out = pten::Cast<float>( auto out = pten::Cast<float>(dev_ctx, dense_x, out_dtype);
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
dense_x,
out_dtype);
// 3. check result // 3. check result
ASSERT_EQ(out.dims().size(), 2); ASSERT_EQ(out.dims().size(), 2);
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/kernels/complex_kernel.h" #include "paddle/pten/kernels/complex_kernel.h"
#include "paddle/pten/api/lib/utils/allocator.h" #include "paddle/pten/api/lib/utils/allocator.h"
...@@ -41,13 +42,10 @@ TEST(DEV_API, conj) { ...@@ -41,13 +42,10 @@ TEST(DEV_API, conj) {
dense_x_data[i] = paddle::complex64(i * 1.0, i * 1.0); dense_x_data[i] = paddle::complex64(i * 1.0, i * 1.0);
} }
paddle::platform::DeviceContextPool& pool = pten::CPUContext dev_ctx;
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
// 2. test API // 2. test API
auto out = pten::Conj<paddle::complex64>( auto out = pten::Conj<paddle::complex64>(dev_ctx, dense_x);
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)), dense_x);
// 3. check result // 3. check result
ASSERT_EQ(out.dims().size(), 2); ASSERT_EQ(out.dims().size(), 2);
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/kernel_registry.h" #include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/copy_kernel.h" #include "paddle/pten/kernels/copy_kernel.h"
...@@ -54,9 +55,8 @@ TEST(DEV_API, copy) { ...@@ -54,9 +55,8 @@ TEST(DEV_API, copy) {
const auto& a = paddle::platform::CPUPlace(); const auto& a = paddle::platform::CPUPlace();
std::cout << typeid(a).name() << std::endl; std::cout << typeid(a).name() << std::endl;
// 2. test API // 2. test API
auto& pool = paddle::platform::DeviceContextPool::Instance(); pten::CPUContext dev_ctx;
auto* dev_ctx = pool.GetByPlace(paddle::platform::CPUPlace()); pten::Copy(dev_ctx, *(dense_src.get()), false, dense_dst.get());
pten::Copy(*dev_ctx, *(dense_src.get()), false, dense_dst.get());
// 3. check result // 3. check result
for (int64_t i = 0; i < dense_src->numel(); i++) { for (int64_t i = 0; i < dense_src->numel(); i++) {
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/kernels/empty_kernel.h" #include "paddle/pten/kernels/empty_kernel.h"
#include "paddle/pten/kernels/full_kernel.h" #include "paddle/pten/kernels/full_kernel.h"
...@@ -30,15 +31,10 @@ using DDim = paddle::framework::DDim; ...@@ -30,15 +31,10 @@ using DDim = paddle::framework::DDim;
TEST(DEV_API, empty) { TEST(DEV_API, empty) {
// 1. create input // 1. create input
paddle::platform::DeviceContextPool& pool = pten::CPUContext dev_ctx;
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
// 2. test API // 2. test API
auto out = pten::Empty<float>( auto out = pten::Empty<float>(dev_ctx, {3, 2}, pten::DataType::INT32);
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
{3, 2},
pten::DataType::INT32);
// 3. check result // 3. check result
ASSERT_EQ(out.dims().size(), 2); ASSERT_EQ(out.dims().size(), 2);
...@@ -59,13 +55,9 @@ TEST(DEV_API, empty_like) { ...@@ -59,13 +55,9 @@ TEST(DEV_API, empty_like) {
auto* dense_x_data = dense_x.mutable_data<float>(); auto* dense_x_data = dense_x.mutable_data<float>();
dense_x_data[0] = 0; dense_x_data[0] = 0;
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
// 2. test API // 2. test API
auto out = pten::EmptyLike<float>( pten::CPUContext dev_ctx;
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)), dense_x); auto out = pten::EmptyLike<float>(dev_ctx, dense_x);
// 3. check result // 3. check result
ASSERT_EQ(out.dims().size(), 2); ASSERT_EQ(out.dims().size(), 2);
...@@ -79,16 +71,9 @@ TEST(DEV_API, full) { ...@@ -79,16 +71,9 @@ TEST(DEV_API, full) {
// 1. create input // 1. create input
float val = 1.0; float val = 1.0;
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
// 2. test API // 2. test API
auto out = pten::Full<float>( pten::CPUContext dev_ctx;
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)), auto out = pten::Full<float>(dev_ctx, {3, 2}, val, pten::DataType::FLOAT32);
{3, 2},
val,
pten::DataType::FLOAT32);
// 3. check result // 3. check result
ASSERT_EQ(out.dims().size(), 2); ASSERT_EQ(out.dims().size(), 2);
...@@ -115,15 +100,10 @@ TEST(DEV_API, full_like) { ...@@ -115,15 +100,10 @@ TEST(DEV_API, full_like) {
dense_x_data[0] = 0; dense_x_data[0] = 0;
float val = 1.0; float val = 1.0;
paddle::platform::DeviceContextPool& pool = pten::CPUContext dev_ctx;
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
// 2. test API // 2. test API
auto out = pten::FullLike<float>( auto out = pten::FullLike<float>(dev_ctx, dense_x, val);
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
dense_x,
val);
// 3. check result // 3. check result
ASSERT_EQ(out.dims().size(), 2); ASSERT_EQ(out.dims().size(), 2);
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/kernels/dot_kernel.h" #include "paddle/pten/kernels/dot_kernel.h"
#include "paddle/pten/api/lib/utils/allocator.h" #include "paddle/pten/api/lib/utils/allocator.h"
...@@ -52,15 +53,9 @@ TEST(DEV_API, dot) { ...@@ -52,15 +53,9 @@ TEST(DEV_API, dot) {
} }
} }
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
// 2. test API // 2. test API
auto out = pten::Dot<float>( pten::CPUContext dev_ctx;
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)), auto out = pten::Dot<float>(dev_ctx, dense_x, dense_y);
dense_x,
dense_y);
// 3. check result // 3. check result
ASSERT_EQ(out.dims().size(), 2); ASSERT_EQ(out.dims().size(), 2);
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/kernels/math_kernel.h" #include "paddle/pten/kernels/math_kernel.h"
#include "paddle/pten/api/lib/utils/allocator.h" #include "paddle/pten/api/lib/utils/allocator.h"
...@@ -54,16 +55,10 @@ TEST(DEV_API, add) { ...@@ -54,16 +55,10 @@ TEST(DEV_API, add) {
dense_y_data[i] = i * 2.0; dense_y_data[i] = i * 2.0;
} }
int axis = 1; int axis = 1;
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
// 2. test API // 2. test API
auto dense_out = pten::Add<float>( pten::CPUContext dev_ctx;
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)), auto dense_out = pten::Add<float>(dev_ctx, dense_x, dense_y, axis);
dense_x,
dense_y,
axis);
// 3. check result // 3. check result
ASSERT_EQ(dense_out.dims().size(), 2); ASSERT_EQ(dense_out.dims().size(), 2);
...@@ -107,16 +102,10 @@ TEST(DEV_API, subtract) { ...@@ -107,16 +102,10 @@ TEST(DEV_API, subtract) {
dense_y_data[i] = i * 2.0; dense_y_data[i] = i * 2.0;
} }
int axis = 1; int axis = 1;
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
// 2. test API // 2. test API
auto dense_out = pten::Subtract<float>( pten::CPUContext dev_ctx;
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)), auto dense_out = pten::Subtract<float>(dev_ctx, dense_x, dense_y, axis);
dense_x,
dense_y,
axis);
// 3. check result // 3. check result
ASSERT_EQ(dense_out.dims().size(), 2); ASSERT_EQ(dense_out.dims().size(), 2);
...@@ -160,16 +149,10 @@ TEST(DEV_API, divide) { ...@@ -160,16 +149,10 @@ TEST(DEV_API, divide) {
dense_y_data[i] = i * 2.0 + 1; dense_y_data[i] = i * 2.0 + 1;
} }
int axis = 1; int axis = 1;
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
// 2. test API // 2. test API
auto dense_out = pten::Divide<float>( pten::CPUContext dev_ctx;
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)), auto dense_out = pten::Divide<float>(dev_ctx, dense_x, dense_y, axis);
dense_x,
dense_y,
axis);
// 3. check result // 3. check result
ASSERT_EQ(dense_out.dims().size(), 2); ASSERT_EQ(dense_out.dims().size(), 2);
...@@ -213,16 +196,10 @@ TEST(DEV_API, multiply) { ...@@ -213,16 +196,10 @@ TEST(DEV_API, multiply) {
dense_y_data[i] = i * 2.0; dense_y_data[i] = i * 2.0;
} }
int axis = 1; int axis = 1;
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
// 2. test API // 2. test API
auto dense_out = pten::Multiply<float>( pten::CPUContext dev_ctx;
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)), auto dense_out = pten::Multiply<float>(dev_ctx, dense_x, dense_y, axis);
dense_x,
dense_y,
axis);
// 3. check result // 3. check result
ASSERT_EQ(dense_out.dims().size(), 2); ASSERT_EQ(dense_out.dims().size(), 2);
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/kernels/flatten_kernel.h" #include "paddle/pten/kernels/flatten_kernel.h"
#include "paddle/pten/api/lib/utils/allocator.h" #include "paddle/pten/api/lib/utils/allocator.h"
...@@ -52,16 +53,10 @@ TEST(DEV_API, flatten) { ...@@ -52,16 +53,10 @@ TEST(DEV_API, flatten) {
dense_x_data[i] = i; dense_x_data[i] = i;
} }
int start_axis = 1, stop_axis = 2; int start_axis = 1, stop_axis = 2;
paddle::platform::DeviceContextPool& pool = pten::CPUContext dev_ctx;
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
// 2. test API // 2. test API
auto out = pten::Flatten<float>( auto out = pten::Flatten<float>(dev_ctx, dense_x, start_axis, stop_axis);
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
dense_x,
start_axis,
stop_axis);
// 3. check result // 3. check result
std::vector<int> expect_shape = {3, 4, 3}; std::vector<int> expect_shape = {3, 4, 3};
......
...@@ -50,13 +50,9 @@ TEST(DEV_API, dot) { ...@@ -50,13 +50,9 @@ TEST(DEV_API, dot) {
} }
std::vector<float> sum(9, 6.0); std::vector<float> sum(9, 6.0);
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* ctx = pool.Get(paddle::platform::CPUPlace());
// 2. test API // 2. test API
auto out = Matmul<float, CPUContext>( pten::CPUContext dev_ctx;
*(static_cast<CPUContext*>(ctx)), dense_x, dense_y, false, false); auto out = Matmul<float, CPUContext>(dev_ctx, dense_x, dense_y, false, false);
// 3. check result // 3. check result
ASSERT_EQ(out.dims().size(), 2); ASSERT_EQ(out.dims().size(), 2);
......
...@@ -42,17 +42,11 @@ TEST(DEV_API, mean) { ...@@ -42,17 +42,11 @@ TEST(DEV_API, mean) {
dense_x_data[i] = i * 1.0; dense_x_data[i] = i * 1.0;
sum += i * 1.0; sum += i * 1.0;
} }
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
std::vector<int64_t> dims = {0, 1}; std::vector<int64_t> dims = {0, 1};
// 2. test API // 2. test API
auto out = pten::Mean<float>( pten::CPUContext dev_ctx;
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)), auto out = pten::Mean<float>(dev_ctx, dense_x, dims, false);
dense_x,
dims,
false);
// 3. check result // 3. check result
ASSERT_EQ(out.dims().size(), 1); ASSERT_EQ(out.dims().size(), 1);
......
...@@ -42,16 +42,11 @@ TEST(DEV_API, reshape) { ...@@ -42,16 +42,11 @@ TEST(DEV_API, reshape) {
for (int i = 0; i < dense_x.numel(); i++) { for (int i = 0; i < dense_x.numel(); i++) {
dense_x_data[i] = i; dense_x_data[i] = i;
} }
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
std::vector<int64_t> shape{12, 3}; std::vector<int64_t> shape{12, 3};
// 2. test API // 2. test API
auto out = pten::Reshape<float>( pten::CPUContext dev_ctx;
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)), auto out = pten::Reshape<float>(dev_ctx, dense_x, shape);
dense_x,
shape);
// 3. check result // 3. check result
std::vector<int64_t> expect_shape = {12, 3}; std::vector<int64_t> expect_shape = {12, 3};
ASSERT_EQ(out.dims()[0], expect_shape[0]); ASSERT_EQ(out.dims()[0], expect_shape[0]);
......
...@@ -44,17 +44,10 @@ TEST(DEV_API, scale) { ...@@ -44,17 +44,10 @@ TEST(DEV_API, scale) {
float bias = 1; float bias = 1;
bool bias_after_scale = true; bool bias_after_scale = true;
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
// 2. test API // 2. test API
auto out = pten::Scale<float>( pten::CPUContext dev_ctx;
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)), auto out =
dense_x, pten::Scale<float>(dev_ctx, dense_x, scale, bias, bias_after_scale);
scale,
bias,
bias_after_scale);
// 3. check result // 3. check result
ASSERT_EQ(out.dims().size(), 2); ASSERT_EQ(out.dims().size(), 2);
...@@ -88,17 +81,10 @@ TEST(DEV_API, scale_host) { ...@@ -88,17 +81,10 @@ TEST(DEV_API, scale_host) {
float bias = 1; float bias = 1;
bool bias_after_scale = true; bool bias_after_scale = true;
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
// 2. test API // 2. test API
auto out = pten::Scale<float>( pten::CPUContext dev_ctx;
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)), auto out =
dense_x, pten::Scale<float>(dev_ctx, dense_x, scale, bias, bias_after_scale);
scale,
bias,
bias_after_scale);
// 3. check result // 3. check result
ASSERT_EQ(out.dims().size(), 2); ASSERT_EQ(out.dims().size(), 2);
......
...@@ -42,18 +42,12 @@ TEST(DEV_API, sum) { ...@@ -42,18 +42,12 @@ TEST(DEV_API, sum) {
dense_x_data[i] = i * 1.0; dense_x_data[i] = i * 1.0;
sum += i * 1.0; sum += i * 1.0;
} }
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
std::vector<int64_t> axis = {0, 1}; std::vector<int64_t> axis = {0, 1};
pten::CPUContext dev_ctx;
// 2. test API // 2. test API
auto out = pten::Sum<float>( auto out =
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)), pten::Sum<float>(dev_ctx, dense_x, axis, pten::DataType::FLOAT32, false);
dense_x,
axis,
pten::DataType::FLOAT32,
false);
// 3. check result // 3. check result
ASSERT_EQ(out.dims().size(), 1); ASSERT_EQ(out.dims().size(), 1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册