From 064bc4b819b505682fa46dedfb2fc4d8d5813c61 Mon Sep 17 00:00:00 2001 From: Wilber Date: Fri, 21 Jan 2022 08:49:34 +0800 Subject: [PATCH] [PTEN] Add cpu context (#38979) * add cpu_context. * update * update * update * update * update * fix ci problem * fix npu ci problem * update * fix ci compile --- paddle/fluid/distributed/service/brpc_utils.h | 3 - .../fluid/distributed/service/heter_client.h | 3 - .../fluid/distributed/service/heter_server.h | 3 - paddle/fluid/distributed/service/server.h | 3 - .../eager_generated/backwards/scale_node.cc | 40 +- paddle/fluid/eager/legacy/prepared_operator.h | 3 - .../details/fetch_async_op_handle.cc | 6 - .../framework/details/fetch_async_op_handle.h | 4 - .../fluid/framework/details/fetch_op_handle.h | 3 - .../fluid/framework/details/op_handle_base.h | 6 - .../details/scale_loss_grad_op_handle.h | 3 - paddle/fluid/framework/device_worker.h | 3 - paddle/fluid/framework/garbage_collector.h | 6 - paddle/fluid/framework/lod_tensor.cc | 6 - paddle/fluid/framework/lod_tensor.h | 6 - paddle/fluid/framework/pten_utils.h | 11 + paddle/fluid/framework/selected_rows.cc | 6 - paddle/fluid/imperative/parallel_context.h | 6 +- paddle/fluid/imperative/prepared_operator.cc | 5 + paddle/fluid/imperative/prepared_operator.h | 3 - paddle/fluid/imperative/reducer.h | 5 - paddle/fluid/memory/malloc.h | 8 +- paddle/fluid/operators/assign_op.h | 6 - paddle/fluid/operators/cast_op.h | 5 +- paddle/fluid/operators/cholesky_solve_op.h | 5 +- paddle/fluid/operators/conj_op.h | 5 +- paddle/fluid/operators/dot_op.h | 12 +- .../elementwise/elementwise_add_op.h | 5 +- .../elementwise/elementwise_div_op.h | 5 +- .../elementwise/elementwise_mul_op.h | 6 +- .../elementwise/elementwise_sub_op.h | 6 +- paddle/fluid/operators/fill_any_like_op.h | 5 +- .../fill_constant_batch_size_like_op.h | 3 +- .../fill_constant_batch_size_like_op_npu.cc | 2 +- paddle/fluid/operators/fill_constant_op.h | 12 +- paddle/fluid/operators/flatten_op.h | 13 +- paddle/fluid/operators/layer_norm_op.h | 1 - paddle/fluid/operators/lu_op.h | 12 +- paddle/fluid/operators/math/blas_impl.h | 387 ++++++++++++++++++ paddle/fluid/operators/math/math_function.cu | 16 + paddle/fluid/operators/matmul_v2_op.h | 21 +- paddle/fluid/operators/memcpy_d2h_op.h | 6 - paddle/fluid/operators/memcpy_h2d_op.h | 6 - paddle/fluid/operators/memcpy_op.h | 6 - .../pscore/heter_listen_and_serv_op.h | 3 - paddle/fluid/operators/recurrent_op.h | 6 - paddle/fluid/operators/reduce_ops/reduce_op.h | 7 +- paddle/fluid/operators/reshape_op.cc | 11 +- paddle/fluid/operators/scale_op.h | 5 +- paddle/fluid/operators/seed_op.cu | 2 +- paddle/fluid/operators/sign_op.h | 6 +- paddle/fluid/operators/transfer_layout_op.h | 6 - paddle/fluid/platform/CMakeLists.txt | 2 +- .../platform/device/mlu/device_context.h | 2 - paddle/fluid/platform/device_context.cc | 51 +-- paddle/fluid/platform/device_context.h | 24 +- paddle/fluid/platform/for_range.h | 16 + paddle/fluid/platform/transform.h | 18 + paddle/pten/backends/CMakeLists.txt | 1 + paddle/pten/backends/all_context.h | 4 +- paddle/pten/backends/cpu/CMakeLists.txt | 6 + paddle/pten/backends/cpu/cpu_context.cc | 93 +++++ paddle/pten/backends/cpu/cpu_context.h | 46 ++- paddle/pten/backends/cpu/forwards.h | 21 + paddle/pten/core/CMakeLists.txt | 2 + paddle/pten/core/device_context.cc | 56 +++ paddle/pten/core/device_context.h | 86 ++++ paddle/pten/tests/core/CMakeLists.txt | 1 + paddle/pten/tests/core/test_device_context.cc | 68 +++ .../pten/tests/kernels/test_cast_dev_api.cc | 10 +- .../pten/tests/kernels/test_conj_dev_api.cc | 8 +- .../pten/tests/kernels/test_copy_dev_api.cc | 6 +- .../tests/kernels/test_creation_dev_api.cc | 38 +- paddle/pten/tests/kernels/test_dot_dev_api.cc | 11 +- .../tests/kernels/test_elementwise_dev_api.cc | 41 +- .../tests/kernels/test_flatten_dev_api.cc | 11 +- .../pten/tests/kernels/test_matmul_dev_api.cc | 8 +- .../pten/tests/kernels/test_mean_dev_api.cc | 12 +- .../tests/kernels/test_reshape_dev_api.cc | 9 +- .../pten/tests/kernels/test_scale_dev_api.cc | 26 +- paddle/pten/tests/kernels/test_sum_dev_api.cc | 12 +- 81 files changed, 1039 insertions(+), 383 deletions(-) create mode 100644 paddle/pten/backends/cpu/CMakeLists.txt create mode 100644 paddle/pten/backends/cpu/cpu_context.cc create mode 100644 paddle/pten/backends/cpu/forwards.h create mode 100644 paddle/pten/core/device_context.cc create mode 100644 paddle/pten/core/device_context.h create mode 100644 paddle/pten/tests/core/test_device_context.cc diff --git a/paddle/fluid/distributed/service/brpc_utils.h b/paddle/fluid/distributed/service/brpc_utils.h index f24e2889b66..47de71d2087 100644 --- a/paddle/fluid/distributed/service/brpc_utils.h +++ b/paddle/fluid/distributed/service/brpc_utils.h @@ -42,9 +42,6 @@ namespace framework { class Scope; class Variable; } // namespace framework -namespace platform { -class DeviceContext; -} // namespace platform } // namespace paddle namespace paddle { diff --git a/paddle/fluid/distributed/service/heter_client.h b/paddle/fluid/distributed/service/heter_client.h index 5fa49bc2411..7ba47ad9a5d 100644 --- a/paddle/fluid/distributed/service/heter_client.h +++ b/paddle/fluid/distributed/service/heter_client.h @@ -37,9 +37,6 @@ namespace paddle { namespace framework { class Scope; } // namespace framework -namespace platform { -class DeviceContext; -} // namespace platform } // namespace paddle namespace paddle { diff --git a/paddle/fluid/distributed/service/heter_server.h b/paddle/fluid/distributed/service/heter_server.h index 201074810cf..094ee603641 100644 --- a/paddle/fluid/distributed/service/heter_server.h +++ b/paddle/fluid/distributed/service/heter_server.h @@ -48,9 +48,6 @@ class Executor; class ProgramDesc; class Scope; } // namespace framework -namespace platform { -class DeviceContext; -} // namespace platform } // namespace paddle DECLARE_double(eager_delete_tensor_gb); diff --git a/paddle/fluid/distributed/service/server.h b/paddle/fluid/distributed/service/server.h index dffe19545ce..ebebedc80ef 100644 --- a/paddle/fluid/distributed/service/server.h +++ b/paddle/fluid/distributed/service/server.h @@ -48,9 +48,6 @@ class Executor; class ProgramDesc; class Scope; } // namespace framework -namespace platform { -class DeviceContext; -} // namespace platform } // namespace paddle namespace paddle { diff --git a/paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.cc b/paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.cc index 99f6c7a8353..cd91209c9cc 100644 --- a/paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.cc +++ b/paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.cc @@ -33,31 +33,39 @@ static void ScaleDeviceDispatch(const pten::DenseTensor& dense_tensor, pten::DenseTensor* dense_out) { switch (dense_tensor.dtype()) { case pten::DataType::FLOAT64: { - pten::ScaleKernel( - dev_ctx, dense_tensor /* tensor */, scale /* scale */, - bias /* bias */, bias_after_scale /* bias_after_scale */, - dense_out /* out tensor */); + pten::ScaleKernel::TYPE>( + static_cast::TYPE&>(dev_ctx), + dense_tensor /* tensor */, scale /* scale */, bias /* bias */, + bias_after_scale /* bias_after_scale */, dense_out /* out tensor */); break; } case pten::DataType::FLOAT32: { - pten::ScaleKernel( - dev_ctx, dense_tensor /* tensor */, scale /* scale */, - bias /* bias */, bias_after_scale /* bias_after_scale */, - dense_out /* out tensor */); + pten::ScaleKernel::TYPE>( + static_cast::TYPE&>(dev_ctx), + dense_tensor /* tensor */, scale /* scale */, bias /* bias */, + bias_after_scale /* bias_after_scale */, dense_out /* out tensor */); break; } case pten::DataType::INT64: { - pten::ScaleKernel( - dev_ctx, dense_tensor /* tensor */, scale /* scale */, - bias /* bias */, bias_after_scale /* bias_after_scale */, - dense_out /* out tensor */); + pten::ScaleKernel::TYPE>( + static_cast::TYPE&>(dev_ctx), + dense_tensor /* tensor */, scale /* scale */, bias /* bias */, + bias_after_scale /* bias_after_scale */, dense_out /* out tensor */); break; } case pten::DataType::INT32: { - pten::ScaleKernel( - dev_ctx, dense_tensor /* tensor */, scale /* scale */, - bias /* bias */, bias_after_scale /* bias_after_scale */, - dense_out /* out tensor */); + pten::ScaleKernel::TYPE>( + static_cast::TYPE&>(dev_ctx), + dense_tensor /* tensor */, scale /* scale */, bias /* bias */, + bias_after_scale /* bias_after_scale */, dense_out /* out tensor */); break; } default: { diff --git a/paddle/fluid/eager/legacy/prepared_operator.h b/paddle/fluid/eager/legacy/prepared_operator.h index 87720fd8f00..7c448a76296 100644 --- a/paddle/fluid/eager/legacy/prepared_operator.h +++ b/paddle/fluid/eager/legacy/prepared_operator.h @@ -31,9 +31,6 @@ namespace paddle { namespace framework { class Variable; } // namespace framework -namespace platform { -class DeviceContext; -} // namespace platform } // namespace paddle namespace pten { diff --git a/paddle/fluid/framework/details/fetch_async_op_handle.cc b/paddle/fluid/framework/details/fetch_async_op_handle.cc index f59d947e279..69741bd3c97 100644 --- a/paddle/fluid/framework/details/fetch_async_op_handle.cc +++ b/paddle/fluid/framework/details/fetch_async_op_handle.cc @@ -18,12 +18,6 @@ #include "paddle/fluid/platform/profiler.h" -namespace paddle { -namespace platform { -class DeviceContext; -} // namespace platform -} // namespace paddle - namespace paddle { namespace framework { namespace details { diff --git a/paddle/fluid/framework/details/fetch_async_op_handle.h b/paddle/fluid/framework/details/fetch_async_op_handle.h index 3e9563ab1ed..da62902c2ec 100644 --- a/paddle/fluid/framework/details/fetch_async_op_handle.h +++ b/paddle/fluid/framework/details/fetch_async_op_handle.h @@ -33,10 +33,6 @@ namespace ir { class Node; } // namespace ir } // namespace framework - -namespace platform { -class DeviceContext; -} // namespace platform } // namespace paddle namespace paddle { diff --git a/paddle/fluid/framework/details/fetch_op_handle.h b/paddle/fluid/framework/details/fetch_op_handle.h index 41deeb0af27..c2e423869ea 100644 --- a/paddle/fluid/framework/details/fetch_op_handle.h +++ b/paddle/fluid/framework/details/fetch_op_handle.h @@ -28,9 +28,6 @@ namespace ir { class Node; } // namespace ir } // namespace framework -namespace platform { -class DeviceContext; -} // namespace platform } // namespace paddle namespace paddle { diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index cbad39c8716..c9021c84cdf 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -25,12 +25,6 @@ #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/macros.h" -namespace paddle { -namespace platform { -class DeviceContext; -} // namespace platform -} // namespace paddle - namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.h b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h index 88fe02a749f..c387450a69e 100644 --- a/paddle/fluid/framework/details/scale_loss_grad_op_handle.h +++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h @@ -28,9 +28,6 @@ namespace ir { class Node; } // namespace ir } // namespace framework -namespace platform { -class DeviceContext; -} // namespace platform } // namespace paddle namespace paddle { diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h index edb87a378dd..d8b14fc0d4c 100644 --- a/paddle/fluid/framework/device_worker.h +++ b/paddle/fluid/framework/device_worker.h @@ -46,9 +46,6 @@ namespace framework { class ProgramDesc; class Scope; } // namespace framework -namespace platform { -class DeviceContext; -} // namespace platform } // namespace paddle #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) diff --git a/paddle/fluid/framework/garbage_collector.h b/paddle/fluid/framework/garbage_collector.h index dbb3ab7e9e6..f5d79d864b5 100644 --- a/paddle/fluid/framework/garbage_collector.h +++ b/paddle/fluid/framework/garbage_collector.h @@ -26,12 +26,6 @@ #include "paddle/fluid/platform/device/mlu/device_context.h" #endif -namespace paddle { -namespace platform { -class DeviceContext; -} // namespace platform -} // namespace paddle - namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/lod_tensor.cc b/paddle/fluid/framework/lod_tensor.cc index 48ba7cc0a2a..a4b9fff8ecd 100644 --- a/paddle/fluid/framework/lod_tensor.cc +++ b/paddle/fluid/framework/lod_tensor.cc @@ -18,12 +18,6 @@ limitations under the License. */ #include "paddle/fluid/framework/version.h" -namespace paddle { -namespace platform { -class DeviceContext; -} // namespace platform -} // namespace paddle - namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/lod_tensor.h b/paddle/fluid/framework/lod_tensor.h index 41cd6b83fd1..14727c190b5 100644 --- a/paddle/fluid/framework/lod_tensor.h +++ b/paddle/fluid/framework/lod_tensor.h @@ -27,12 +27,6 @@ limitations under the License. */ #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/place.h" -namespace paddle { -namespace platform { -class DeviceContext; -} // namespace platform -} // namespace paddle - namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/pten_utils.h b/paddle/fluid/framework/pten_utils.h index 8bbd4f7f3c9..a4493f3d3e5 100644 --- a/paddle/fluid/framework/pten_utils.h +++ b/paddle/fluid/framework/pten_utils.h @@ -75,5 +75,16 @@ class KernelArgsNameMaker { void SetAllocationForOutputTenosr(pten::DenseTensor* tensor, const platform::Place& place); +// TODO(Wilber): support others device context. +template +struct ConvertToPtenContext { + using TYPE = T; +}; + +template <> +struct ConvertToPtenContext { + using TYPE = pten::CPUContext; +}; + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/selected_rows.cc b/paddle/fluid/framework/selected_rows.cc index c67653953f8..6cad0915be7 100644 --- a/paddle/fluid/framework/selected_rows.cc +++ b/paddle/fluid/framework/selected_rows.cc @@ -14,12 +14,6 @@ limitations under the License. */ #include "paddle/fluid/framework/selected_rows.h" -namespace paddle { -namespace platform { -class DeviceContext; -} // namespace platform -} // namespace paddle - namespace paddle { namespace framework { diff --git a/paddle/fluid/imperative/parallel_context.h b/paddle/fluid/imperative/parallel_context.h index 8bdfccc1442..eafddf5fcae 100644 --- a/paddle/fluid/imperative/parallel_context.h +++ b/paddle/fluid/imperative/parallel_context.h @@ -16,17 +16,13 @@ #include #include +#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/place.h" namespace paddle { -namespace platform { -class DeviceContext; -} // namespace platform - namespace framework { class Variable; } // namespace framework - } // namespace paddle namespace paddle { diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index bb08191af98..d28595a6a4c 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -173,6 +173,11 @@ PreparedOp PrepareImpl(const NameVarMap& ins, << " | kernel key: " << pt_kernel_key << " | kernel: " << pt_kernel; + if (platform::is_cpu_place(expected_kernel_key.place_)) { + auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace()); + return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature, + pt_kernel, cpu_ctx); + } // TODO(chenweihang): using CPUKernel when miss device kernel case return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature, pt_kernel, dev_ctx); diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index 09cc480fe17..1a66fe0a056 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -37,9 +37,6 @@ namespace paddle { namespace framework { class Variable; } // namespace framework -namespace platform { -class DeviceContext; -} // namespace platform } // namespace paddle namespace paddle { diff --git a/paddle/fluid/imperative/reducer.h b/paddle/fluid/imperative/reducer.h index 3c03babc52c..b99d7adc0c7 100644 --- a/paddle/fluid/imperative/reducer.h +++ b/paddle/fluid/imperative/reducer.h @@ -33,11 +33,6 @@ #include "paddle/fluid/platform/for_range.h" namespace paddle { -namespace platform { -class DeviceContext; - -} // namespace platform - namespace imperative { class ParallelContext; class VarBase; diff --git a/paddle/fluid/memory/malloc.h b/paddle/fluid/memory/malloc.h index 8830c46a177..6443e91f08c 100644 --- a/paddle/fluid/memory/malloc.h +++ b/paddle/fluid/memory/malloc.h @@ -19,13 +19,9 @@ limitations under the License. */ #include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/stream/stream.h" +#include "paddle/pten/core/device_context.h" namespace paddle { - -namespace platform { -class DeviceContext; -} // platform - namespace memory { using pten::Allocation; @@ -37,7 +33,7 @@ extern std::shared_ptr AllocShared(const platform::Place& place, extern AllocationPtr Alloc(const platform::Place& place, size_t size); -extern AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size); +extern AllocationPtr Alloc(const pten::DeviceContext& dev_ctx, size_t size); extern uint64_t Release(const platform::Place& place); diff --git a/paddle/fluid/operators/assign_op.h b/paddle/fluid/operators/assign_op.h index 1dd28c9389d..5fe2ebb2074 100644 --- a/paddle/fluid/operators/assign_op.h +++ b/paddle/fluid/operators/assign_op.h @@ -19,12 +19,6 @@ limitations under the License. */ #include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/platform/device_context.h" -namespace paddle { -namespace platform { -class DeviceContext; -} // namespace platform -} // namespace paddle - namespace pten { class DenseTensor; } // namespace pten diff --git a/paddle/fluid/operators/cast_op.h b/paddle/fluid/operators/cast_op.h index e82eeb240c1..9bf08d072da 100644 --- a/paddle/fluid/operators/cast_op.h +++ b/paddle/fluid/operators/cast_op.h @@ -67,7 +67,10 @@ class CastOpKernel : public framework::OpKernel { static_cast(out_dtype)); // call new kernel - pten::CastKernel(dev_ctx, *in, pt_out_dtype, out); + pten::CastKernel( + static_cast::TYPE&>(dev_ctx), + *in, pt_out_dtype, out); } }; diff --git a/paddle/fluid/operators/cholesky_solve_op.h b/paddle/fluid/operators/cholesky_solve_op.h index 157679f4fc9..4b1d075de91 100644 --- a/paddle/fluid/operators/cholesky_solve_op.h +++ b/paddle/fluid/operators/cholesky_solve_op.h @@ -202,7 +202,10 @@ class CholeskySolveGradKernel : public framework::OpKernel { commonterm_for_range(commonterm_functor); commonterm_conj = helper.Transpose(commonterm_conj); - pten::AddKernel(dev_ctx, commonterm, commonterm_conj, -1, &commonterm); + pten::AddKernel( + static_cast::TYPE &>(dev_ctx), + commonterm, commonterm_conj, -1, &commonterm); auto mat_dim_u = math::CreateMatrixDescriptor(u_bst.dims(), 0, false); auto mat_dim_c = diff --git a/paddle/fluid/operators/conj_op.h b/paddle/fluid/operators/conj_op.h index e6958ed18b8..9a5a467cea3 100644 --- a/paddle/fluid/operators/conj_op.h +++ b/paddle/fluid/operators/conj_op.h @@ -36,7 +36,10 @@ class ConjKernel : public framework::OpKernel { auto& dev_ctx = context.device_context(); // call new kernel - pten::ConjKernel(dev_ctx, *x, out); + pten::ConjKernel( + static_cast::TYPE&>(dev_ctx), + *x, out); } }; diff --git a/paddle/fluid/operators/dot_op.h b/paddle/fluid/operators/dot_op.h index ed8eb64d9e1..c5d43ef0126 100644 --- a/paddle/fluid/operators/dot_op.h +++ b/paddle/fluid/operators/dot_op.h @@ -41,7 +41,11 @@ class DotKernel : public framework::OpKernel { out->mutable_data(x->place()); // call new kernel - pten::DotKernel(dev_ctx, *x, *y, out); + pten::DotKernel::TYPE>( + static_cast::TYPE&>(dev_ctx), + *x, *y, out); } }; @@ -61,8 +65,10 @@ class DotGradKernel : public framework::OpKernel { auto& dev_ctx = ctx.device_context(); // call new kernel - pten::DotGradKernel(dev_ctx, *tensor_x, *tensor_y, *tensor_dout, - tensor_dx, tensor_dy); + pten::DotGradKernel( + static_cast::TYPE&>(dev_ctx), + *tensor_x, *tensor_y, *tensor_dout, tensor_dx, tensor_dy); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.h b/paddle/fluid/operators/elementwise/elementwise_add_op.h index 622a6d7edb7..a4897a06d56 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.h @@ -61,7 +61,10 @@ class ElementwiseAddKernel : public framework::OpKernel { auto pt_x = paddle::experimental::MakePtenDenseTensor(*x); auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); auto pt_z = paddle::experimental::MakePtenDenseTensor(*z); - pten::AddKernel(dev_ctx, *pt_x.get(), *pt_y.get(), axis, pt_z.get()); + pten::AddKernel( + static_cast::TYPE &>(dev_ctx), + *pt_x.get(), *pt_y.get(), axis, pt_z.get()); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h index d9f7bbc56a9..44f695278dc 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h @@ -51,7 +51,10 @@ class ElementwiseDivKernel : public framework::OpKernel { auto pt_x = paddle::experimental::MakePtenDenseTensor(*x); auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); auto pt_z = paddle::experimental::MakePtenDenseTensor(*z); - pten::DivideKernel(dev_ctx, *pt_x.get(), *pt_y.get(), axis, pt_z.get()); + pten::DivideKernel( + static_cast::TYPE&>(dev_ctx), + *pt_x.get(), *pt_y.get(), axis, pt_z.get()); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h index 687340b668a..d918407930d 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h @@ -124,8 +124,10 @@ class ElementwiseMulKernel : public framework::OpKernel { auto pt_x = paddle::experimental::MakePtenDenseTensor(*x_lod); auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); auto pt_z = paddle::experimental::MakePtenDenseTensor(*z_lod); - pten::MultiplyKernel(dev_ctx, *pt_x.get(), *pt_y.get(), axis, - pt_z.get()); + pten::MultiplyKernel( + static_cast::TYPE&>(dev_ctx), + *pt_x.get(), *pt_y.get(), axis, pt_z.get()); } else { PADDLE_THROW(platform::errors::InvalidArgument( "X's type[%s] is not supported by elementwise_op. X's type should be " diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.h b/paddle/fluid/operators/elementwise/elementwise_sub_op.h index 0d889ef26c9..46d4a93e804 100644 --- a/paddle/fluid/operators/elementwise/elementwise_sub_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.h @@ -51,8 +51,10 @@ class ElementwiseSubKernel : public framework::OpKernel { auto pt_x = paddle::experimental::MakePtenDenseTensor(*x); auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); auto pt_z = paddle::experimental::MakePtenDenseTensor(*z); - pten::SubtractKernel(dev_ctx, *pt_x.get(), *pt_y.get(), axis, - pt_z.get()); + pten::SubtractKernel( + static_cast::TYPE&>(dev_ctx), + *pt_x.get(), *pt_y.get(), axis, pt_z.get()); } }; diff --git a/paddle/fluid/operators/fill_any_like_op.h b/paddle/fluid/operators/fill_any_like_op.h index d5b3d89c3d7..3ebda1a074a 100644 --- a/paddle/fluid/operators/fill_any_like_op.h +++ b/paddle/fluid/operators/fill_any_like_op.h @@ -62,7 +62,10 @@ class FillAnyLikeKernel : public framework::OpKernel { const auto& dev_ctx = context.template device_context(); // call new kernel - pten::FullLikeKernel(dev_ctx, value, out); + pten::FullLikeKernel( + static_cast::TYPE&>(dev_ctx), + value, out); } }; diff --git a/paddle/fluid/operators/fill_constant_batch_size_like_op.h b/paddle/fluid/operators/fill_constant_batch_size_like_op.h index 432a9968ab0..4c90daa39f9 100644 --- a/paddle/fluid/operators/fill_constant_batch_size_like_op.h +++ b/paddle/fluid/operators/fill_constant_batch_size_like_op.h @@ -57,9 +57,9 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel { } platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(ctx.GetPlace()); bool cpu_place = force_cpu || ctx.GetPlace() == platform::CPUPlace(); if (cpu_place) { + auto &dev_ctx = *pool.Get(platform::CPUPlace()); math::SetConstant functor; out->mutable_data(platform::CPUPlace(), data_type); functor(reinterpret_cast(dev_ctx), @@ -67,6 +67,7 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel { } #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (!cpu_place) { + auto &dev_ctx = *pool.Get(ctx.GetPlace()); math::SetConstant functor; out->mutable_data(ctx.GetPlace(), data_type); functor(reinterpret_cast(dev_ctx), diff --git a/paddle/fluid/operators/fill_constant_batch_size_like_op_npu.cc b/paddle/fluid/operators/fill_constant_batch_size_like_op_npu.cc index 7b02781ff0c..6b07b021d13 100644 --- a/paddle/fluid/operators/fill_constant_batch_size_like_op_npu.cc +++ b/paddle/fluid/operators/fill_constant_batch_size_like_op_npu.cc @@ -67,9 +67,9 @@ class FillConstantBatchSizeLikeOpNPUKernel : public framework::OpKernel { } platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(ctx.GetPlace()); bool cpu_place = force_cpu || ctx.GetPlace() == platform::CPUPlace(); if (cpu_place) { + auto &dev_ctx = *pool.Get(platform::CPUPlace()); math::SetConstant functor; out->mutable_data(platform::CPUPlace(), data_type); functor(reinterpret_cast(dev_ctx), diff --git a/paddle/fluid/operators/fill_constant_op.h b/paddle/fluid/operators/fill_constant_op.h index 32cd07c916b..9e9bd2e0fbb 100644 --- a/paddle/fluid/operators/fill_constant_op.h +++ b/paddle/fluid/operators/fill_constant_op.h @@ -102,7 +102,6 @@ class FillConstantKernel : public framework::OpKernel { } platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(ctx.GetPlace()); int actual_place = place_type; if (actual_place == -1) { @@ -123,12 +122,14 @@ class FillConstantKernel : public framework::OpKernel { : ""); tensor->mutable_data(platform::CPUPlace(), data_type); math::SetConstant functor; + auto &dev_ctx = *pool.Get(platform::CPUPlace()); functor(reinterpret_cast(dev_ctx), tensor, static_cast(value)); } else if (actual_place == 1) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) tensor->mutable_data(ctx.GetPlace(), data_type); math::SetConstant functor; + auto &dev_ctx = *pool.Get(ctx.GetPlace()); functor(reinterpret_cast(dev_ctx), tensor, static_cast(value)); #else @@ -138,9 +139,11 @@ class FillConstantKernel : public framework::OpKernel { } else if (actual_place == 2) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) tensor->mutable_data(platform::CUDAPinnedPlace(), data_type); - math::SetConstant functor; - functor(reinterpret_cast(dev_ctx), - tensor, static_cast(value)); + math::SetConstant functor; + auto &dev_ctx = *pool.Get(platform::CUDAPinnedPlace()); + functor( + reinterpret_cast(dev_ctx), + tensor, static_cast(value)); #else PADDLE_THROW(platform::errors::PreconditionNotMet( "PaddlePaddle should compile with GPU.")); @@ -149,6 +152,7 @@ class FillConstantKernel : public framework::OpKernel { #ifdef PADDLE_WITH_XPU tensor->mutable_data(ctx.GetPlace(), data_type); math::SetConstant functor; + auto &dev_ctx = *pool.Get(ctx.GetPlace()); functor(reinterpret_cast(dev_ctx), tensor, static_cast(value)); #else diff --git a/paddle/fluid/operators/flatten_op.h b/paddle/fluid/operators/flatten_op.h index 624bc9c567f..2a9c2b27d23 100644 --- a/paddle/fluid/operators/flatten_op.h +++ b/paddle/fluid/operators/flatten_op.h @@ -132,8 +132,11 @@ class FlattenContiguousRangeKernel : public framework::OpKernel { auto &dev_ctx = context.device_context(); // call new kernel - pten::FlattenKernel(dev_ctx, *in, start_axis, stop_axis, - out); + pten::FlattenKernel::TYPE>( + static_cast::TYPE &>(dev_ctx), + *in, start_axis, stop_axis, out); } }; @@ -150,7 +153,11 @@ class FlattenContiguousRangeGradKernel : public framework::OpKernel { auto &dev_ctx = ctx.device_context(); // call new kernel - pten::FlattenGradKernel(dev_ctx, *d_out, *xshape, d_x); + pten::FlattenGradKernel::TYPE>( + static_cast::TYPE &>(dev_ctx), + *d_out, *xshape, d_x); } }; diff --git a/paddle/fluid/operators/layer_norm_op.h b/paddle/fluid/operators/layer_norm_op.h index 970ec694120..ad7c0cc218b 100644 --- a/paddle/fluid/operators/layer_norm_op.h +++ b/paddle/fluid/operators/layer_norm_op.h @@ -31,7 +31,6 @@ namespace paddle { namespace platform { class CPUDeviceContext; class CUDADeviceContext; -class DeviceContext; } // namespace platform } // namespace paddle diff --git a/paddle/fluid/operators/lu_op.h b/paddle/fluid/operators/lu_op.h index f78c5b9d361..6beef1add8e 100644 --- a/paddle/fluid/operators/lu_op.h +++ b/paddle/fluid/operators/lu_op.h @@ -221,7 +221,11 @@ void Tensor_Add(const DeviceContext& dev_ctx, const framework::Tensor& src1, out->Resize(src1.dims()); out->mutable_data(dev_ctx.GetPlace()); - pten::AddKernel(dev_ctx, src1, src2, -1, out); + pten::AddKernel< + T, typename paddle::framework::ConvertToPtenContext::TYPE>( + static_cast::TYPE&>(dev_ctx), + src1, src2, -1, out); } template @@ -230,7 +234,11 @@ void Tensor_Sub(const DeviceContext& dev_ctx, const framework::Tensor& src1, out->Resize(src1.dims()); out->mutable_data(dev_ctx.GetPlace()); - pten::SubtractKernel(dev_ctx, src1, src2, -1, out); + pten::SubtractKernel< + T, typename paddle::framework::ConvertToPtenContext::TYPE>( + static_cast::TYPE&>(dev_ctx), + src1, src2, -1, out); } template diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index be9cf1e3448..80b7acc6103 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. #pragma once +#include "paddle/pten/backends/cpu/cpu_context.h" #ifdef PADDLE_WITH_MKLML #include #endif @@ -819,6 +820,12 @@ T *Blas::GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int K) const { return CBlas::GEMM_ALLOC(id, M, N, K); } +template <> +template +T *Blas::GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, + const int N, const int K) const { + return CBlas::GEMM_ALLOC(id, M, N, K); +} template <> template @@ -829,6 +836,15 @@ void Blas::GEMM_PACK(const CBLAS_IDENTIFIER id, const int ld, T *dst) const { CBlas::GEMM_PACK(CblasRowMajor, id, trans, M, N, K, alpha, src, ld, dst); } +template <> +template +void Blas::GEMM_PACK(const CBLAS_IDENTIFIER id, + const CBLAS_TRANSPOSE trans, int M, + int N, int K, const T alpha, + const T *src, const int ld, + T *dst) const { + CBlas::GEMM_PACK(CblasRowMajor, id, trans, M, N, K, alpha, src, ld, dst); +} template <> template @@ -838,12 +854,26 @@ void Blas::GEMM_COMPUTE( CBlas::GEMM_COMPUTE(CblasRowMajor, transA, transB, M, N, K, A, lda, B, ldb, beta, C, ldc); } +template <> +template +void Blas::GEMM_COMPUTE(int transA, int transB, int M, int N, + int K, const T *A, const int lda, + const T *B, const int ldb, T beta, + T *C, const int ldc) const { + CBlas::GEMM_COMPUTE(CblasRowMajor, transA, transB, M, N, K, A, lda, B, ldb, + beta, C, ldc); +} template <> template void Blas::GEMM_FREE(T *data) const { CBlas::GEMM_FREE(data); } +template <> +template +void Blas::GEMM_FREE(T *data) const { + CBlas::GEMM_FREE(data); +} #endif template <> @@ -858,6 +888,18 @@ void Blas::GEMM(CBLAS_TRANSPOSE transA, CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); } +template <> +template +void Blas::GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, int M, int N, int K, + T alpha, const T *A, const T *B, T beta, + T *C) const { + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, + beta, C, ldc); +} template <> template @@ -869,6 +911,15 @@ void Blas::GEMM(bool transA, bool transB, int M, transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); } +template <> +template +void Blas::GEMM(bool transA, bool transB, int M, int N, int K, + T alpha, const T *A, int lda, const T *B, + int ldb, T beta, T *C, int ldc) const { + CBlas::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans, + transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A, + lda, B, ldb, beta, C, ldc); +} template <> template @@ -880,6 +931,15 @@ void Blas::GEMM(CBLAS_TRANSPOSE transA, CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); } +template <> +template +void Blas::GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, int M, int N, int K, + T alpha, const T *A, int lda, const T *B, + int ldb, T beta, T *C, int ldc) const { + CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, + beta, C, ldc); +} template template @@ -920,12 +980,22 @@ void Blas::AXPY(int n, T alpha, const T *x, T *y) const { CBlas::AXPY(n, alpha, x, 1, y, 1); } +template <> +template +void Blas::AXPY(int n, T alpha, const T *x, T *y) const { + CBlas::AXPY(n, alpha, x, 1, y, 1); +} template <> template void Blas::VCOPY(int n, const T *x, T *y) const { CBlas::VCOPY(n, x, 1, y, 1); } +template <> +template +void Blas::VCOPY(int n, const T *x, T *y) const { + CBlas::VCOPY(n, x, 1, y, 1); +} template <> template @@ -942,6 +1012,20 @@ void Blas::VADD(int n, const T *x, const T *y, } #endif } +template <> +template +void Blas::VADD(int n, const T *x, const T *y, T *z) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VADD(n, x, y, z); +#else + if (x == z) { + this->template AXPY(n, (T)(1.), y, z); + } else { + this->template VCOPY(n, y, z); + this->template AXPY(n, (T)(1.), x, z); + } +#endif +} template <> template @@ -956,6 +1040,18 @@ void Blas::VSUB(int n, const T *x, const T *y, } #endif } +template <> +template +void Blas::VSUB(int n, const T *x, const T *y, T *z) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VSUB(n, x, y, z); +#else + // try to find if openblas support vsub + for (int i = 0; i < n; ++i) { + z[i] = x[i] - y[i]; + } +#endif +} template <> template @@ -970,6 +1066,18 @@ void Blas::VMUL(int n, const T *x, const T *y, } #endif } +template <> +template +void Blas::VMUL(int n, const T *x, const T *y, T *z) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VMUL(n, x, y, z); +#else + // try to find if openblas support vmul + for (int i = 0; i < n; ++i) { + z[i] = x[i] * y[i]; + } +#endif +} template <> template @@ -984,6 +1092,18 @@ void Blas::VDIV(int n, const T *x, const T *y, } #endif } +template <> +template +void Blas::VDIV(int n, const T *x, const T *y, T *z) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VDIV(n, x, y, z); +#else + // try to find if openblas support vdiv + for (int i = 0; i < n; ++i) { + z[i] = x[i] / y[i]; + } +#endif +} template <> template @@ -997,6 +1117,18 @@ void Blas::VEXP(int n, const T *x, T *y) const { } #endif } +template <> +template +void Blas::VEXP(int n, const T *x, T *y) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VEXP(n, x, y); +#else + // try to find if openblas support vexp + for (int i = 0; i < n; ++i) { + y[i] = std::exp(x[i]); + } +#endif +} template <> template @@ -1009,6 +1141,17 @@ void Blas::VSQUARE(int n, const T *x, T *y) const { } #endif } +template <> +template +void Blas::VSQUARE(int n, const T *x, T *y) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VSQUARE(n, x, y); +#else + for (int i = 0; i < n; ++i) { + y[i] = x[i] * x[i]; + } +#endif +} template <> template @@ -1022,6 +1165,17 @@ void Blas::VPOW(int n, const T *x, T a, } #endif } +template <> +template +void Blas::VPOW(int n, const T *x, T a, T *y) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VPOW(n, x, a, y); +#else + for (int i = 0; i < n; ++i) { + y[i] = std::pow(x[i], a); + } +#endif +} template <> template @@ -1037,6 +1191,20 @@ T Blas::DOT(int n, const T *x, const T *y) const { return sum; #endif } +template <> +template +T Blas::DOT(int n, const T *x, const T *y) const { +#ifdef PADDLE_WITH_MKLML + return CBlas::DOT(n, x, 1, y, 1); +#else + // try to find if openblas support cblas_dot + T sum = 0; + for (int i = 0; i < n; ++i) { + sum += x[i] * y[i]; + } + return sum; +#endif +} template <> template @@ -1050,6 +1218,18 @@ void Blas::SCAL(int n, const T a, T *x) const { } #endif } +template <> +template +void Blas::SCAL(int n, const T a, T *x) const { +#ifdef PADDLE_WITH_MKLML + CBlas::SCAL(n, a, x, 1); +#else + // try to find if openblas support cblas_scal + for (int i = 0; i < n; ++i) { + x[i] = a * x[i]; + } +#endif +} template <> template @@ -1065,6 +1245,20 @@ T Blas::ASUM(int n, T *x, int inc) const { #endif return sum; } +template <> +template +T Blas::ASUM(int n, T *x, int inc) const { + auto sum = static_cast(0.0); +#ifdef PADDLE_WITH_MKLML + sum = CBlas::ASUM(n, x, inc); +#else + // TODO(jczaja): check if openblas does provide cblas_sasum/cblas_dasum + for (int c = 0; c < n; ++c) { + sum += x[c]; + } +#endif + return sum; +} template <> template @@ -1074,6 +1268,13 @@ void Blas::GEMV(bool trans_a, int M, int N, T alpha, CBLAS_TRANSPOSE transA = !trans_a ? CblasNoTrans : CblasTrans; CBlas::GEMV(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1); } +template <> +template +void Blas::GEMV(bool trans_a, int M, int N, T alpha, + const T *A, const T *B, T beta, T *C) const { + CBLAS_TRANSPOSE transA = !trans_a ? CblasNoTrans : CblasTrans; + CBlas::GEMV(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1); +} template <> template @@ -1112,6 +1313,45 @@ void Blas::BatchedGEMM( } #endif } +template <> +template +void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, int M, int N, + int K, T alpha, const T *A, const T *B, + T beta, T *C, int batchCount, + int64_t strideA, + int64_t strideB) const { + PADDLE_ENFORCE_NOT_NULL( + A, platform::errors::InvalidArgument("Pointer A should not be null.")); + PADDLE_ENFORCE_NOT_NULL( + B, platform::errors::InvalidArgument("Pointer B should not be null.")); + PADDLE_ENFORCE_NOT_NULL( + C, platform::errors::InvalidArgument("Pointer C should not be null.")); +#ifdef PADDLE_WITH_MKLML + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + auto a_array = std::vector(batchCount); + auto b_array = std::vector(batchCount); + auto c_array = std::vector(batchCount); + for (int k = 0; k < batchCount; ++k) { + a_array[k] = &A[k * strideA]; + b_array[k] = &B[k * strideB]; + c_array[k] = &C[k * M * N]; + } + + CBlas::GEMM_BATCH(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha, + a_array.data(), &lda, b_array.data(), &ldb, &beta, + c_array.data(), &ldc, 1 /* group_count */, &batchCount); +#else + for (int k = 0; k < batchCount; ++k) { + auto *Ak = &A[k * strideA]; + auto *Bk = &B[k * strideB]; + auto *Ck = &C[k * M * N]; + this->template GEMM(transA, transB, M, N, K, alpha, Ak, Bk, beta, Ck); + } +#endif +} template <> template @@ -1132,6 +1372,27 @@ void Blas::BatchedGEMM( } #endif } +template <> +template +void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, int M, int N, + int K, T alpha, const T **A, + const T **B, T beta, T **C, + int batchCount) const { +#ifdef PADDLE_WITH_MKLML + const int lda = (std::max)((transA == CblasNoTrans) ? K : M, 1); + const int ldb = (std::max)((transB == CblasNoTrans) ? N : K, 1); + const int ldc = (std::max)(N, 1); + CBlas::GEMM_BATCH(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha, A, + &lda, B, &ldb, &beta, C, &ldc, 1 /* group_count */, + &batchCount); +#else + for (int k = 0; k < batchCount; ++k) { + this->template GEMM(transA, transB, M, N, K, alpha, A[k], B[k], beta, + C[k]); + } +#endif +} #if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \ !defined(PADDLE_WITH_HIP) // @{ Group Blas MKLML: BatchedGEMMWithHead @@ -1204,6 +1465,75 @@ void Blas::BatchedGEMMWithHead( } } } +template <> +template +void Blas::BatchedGEMMWithHead( + CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int W1, int H1, int W2, + int H2, T alpha, const T *A, const T *B, T beta, T *C, int batchCount, + int64_t strideA, int64_t strideB, int64_t head_number, + bool split_b_vertical) const { + int lda = (transA == CblasNoTrans) ? W1 : H1; + int ldb = (transB == CblasNoTrans) ? W2 : H2; + auto a_array = std::vector(batchCount); + auto b_array = std::vector(batchCount); + auto c_array = std::vector(batchCount); + + if (split_b_vertical) { + int ldc = W2; + int sub_width = W2 / head_number; + + for (int i = 0; i < head_number; i++) { + int sub_matA_offset = (transA == CblasNoTrans) + ? i * (W1 / head_number) + : i * (W1 / head_number) * H1; + int sub_matB_offset = (transB == CblasNoTrans) + ? i * (W2 / head_number) + : i * (W2 / head_number) * H2; + int sub_matC_offset = i * W2 / head_number; + for (int k = 0; k < batchCount; ++k) { + a_array[k] = &A[k * strideA] + sub_matA_offset; + b_array[k] = &B[k * strideB] + sub_matB_offset; + c_array[k] = &C[k * H1 * W2] + sub_matC_offset; + } + + CBlas::GEMM_BATCH(CblasRowMajor, &transA, &transB, &H1, &sub_width, + &H2, &alpha, a_array.data(), &lda, b_array.data(), + &ldb, &beta, c_array.data(), &ldc, + 1 /* group_count */, &batchCount); + } + + } else { + PADDLE_ENFORCE_EQ( + W1, H2, + platform::errors::InvalidArgument( + "The fisrt matrix width should be same as second matrix height," + "but received fisrt matrix width %d" + ", second matrix height %d", + W1, H2)); + int ldc = W2 * head_number; + int sub_width = W1 / head_number; + + for (int i = 0; i < head_number; i++) { + int sub_matA_offset = (transA == CblasNoTrans) + ? i * (W1 / head_number) + : i * (W1 / head_number) * H1; + int sub_matB_offset = (transB == CblasNoTrans) + ? i * (W1 / head_number) * W2 + : i * (W1 / head_number); + int sub_matC_offset = i * W2; + for (int k = 0; k < batchCount; ++k) { + a_array[k] = &A[k * strideA] + sub_matA_offset; + b_array[k] = &B[k * strideB] + sub_matB_offset; + c_array[k] = &C[k * H1 * head_number * W2] + sub_matC_offset; + } + + CBlas::GEMM_BATCH(CblasRowMajor, &transA, &transB, &H1, &W2, + &sub_width, &alpha, a_array.data(), &lda, + b_array.data(), &ldb, &beta, c_array.data(), &ldc, + 1 /* group_count */, &batchCount); + } + } +} #endif // @} End Group Blas MKLML: BatchedGEMMWithHead template @@ -1241,6 +1571,31 @@ void Blas::MatMul(const int M, const int N, CBlas::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, static_cast(1), A, K, B, N, static_cast(0), C, N); } +template <> +template +void Blas::MatMul(const int M, const int N, const int K, + const T *A, const T *B, T *C) const { +#ifdef PADDLE_WITH_LIBXSMM + // Refer to https://github.com/hfp/libxsmm/blob/master/README.md + // But the threshold is custom constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20; + + // Since the matrix is very small, + // so the unit of calculation is already very fast, + // and the if( M*N*K < LIBXSMM_THRESHOLD) would be overhead, + // use xsmm directly. + // Note: SMM use ColMajor + const char transa = 'N'; + const char transb = 'N'; + const T alpha = static_cast(1); + const T beta = static_cast(0); + CBlas::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &N, A, &K, &beta, + C, &N); + return; +#endif + + CBlas::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, + static_cast(1), A, K, B, N, static_cast(0), C, N); +} template template @@ -1443,6 +1798,18 @@ void Blas::VMERF(int n, const T *a, T *y, } #endif } +template <> +template +void Blas::VMERF(int n, const T *a, T *y, + int64_t mode) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VMERF(n, a, y, mode); +#else + for (int i = 0; i < n; ++i) { + y[i] = std::erf(a[i]); + } +#endif +} #ifdef PADDLE_WITH_MKLML template <> @@ -1455,6 +1822,17 @@ void Blas::CSRMM( CBlas::CSRMM(transa, m, n, k, alpha, matdescra, val, indx, pntrb, pntre, b, ldb, beta, c, ldc); } +template <> +template +void Blas::CSRMM(const char *transa, const int *m, + const int *n, const int *k, const T *alpha, + const char *matdescra, const T *val, + const int *indx, const int *pntrb, + const int *pntre, const T *b, const int *ldb, + const T *beta, T *c, const int *ldc) const { + CBlas::CSRMM(transa, m, n, k, alpha, matdescra, val, indx, pntrb, pntre, b, + ldb, beta, c, ldc); +} #endif template <> @@ -1467,6 +1845,15 @@ void Blas::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo, CBlas::TRSM(CblasRowMajor, side, uplo, transA, diag, M, N, alpha, A, lda, B, ldb); } +template <> +template +void Blas::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo, + CBLAS_TRANSPOSE transA, CBLAS_DIAG diag, + int M, int N, T alpha, const T *A, int lda, + T *B, int ldb) const { + CBlas::TRSM(CblasRowMajor, side, uplo, transA, diag, M, N, alpha, A, lda, + B, ldb); +} } // namespace math } // namespace operators diff --git a/paddle/fluid/operators/math/math_function.cu b/paddle/fluid/operators/math/math_function.cu index 960453dbe65..9ade45ee743 100644 --- a/paddle/fluid/operators/math/math_function.cu +++ b/paddle/fluid/operators/math/math_function.cu @@ -44,6 +44,22 @@ template struct SetConstant>; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant>; +template struct SetConstant>; + #define DEFINE_GPU_TRANS(RANK) \ template struct Transpose; \ template struct Transpose; \ diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h index 67b0fec734c..0e1c6b82e41 100644 --- a/paddle/fluid/operators/matmul_v2_op.h +++ b/paddle/fluid/operators/matmul_v2_op.h @@ -53,7 +53,10 @@ class MatMulV2Kernel : public framework::OpKernel { Out->mutable_data(X->place()); // call new kernel - pten::MatmulKernel(dev_ctx, *X, *Y, trans_x, trans_y, Out); + pten::MatmulKernel( + static_cast::TYPE&>(dev_ctx), + *X, *Y, trans_x, trans_y, Out); } }; @@ -149,8 +152,10 @@ class MatMulV2GradKernel : public framework::OpKernel { auto& dev_ctx = ctx.device_context(); // call new kernel - pten::MatmulGradKernel(dev_ctx, *x, *y, *dout, transpose_x, transpose_y, - dx, dy); + pten::MatmulGradKernel( + static_cast::TYPE&>(dev_ctx), + *x, *y, *dout, transpose_x, transpose_y, dx, dy); } }; @@ -178,8 +183,10 @@ class MatMulV2DoubleGradKernel : public framework::OpKernel { auto& dev_ctx = context.device_context(); // call new kernel - pten::MatmulDoubleGradKernel(dev_ctx, *x, *y, *dout, *ddx, *ddy, - transpose_x, transpose_y, dx, dy, ddout); + pten::MatmulDoubleGradKernel( + static_cast::TYPE&>(dev_ctx), + *x, *y, *dout, *ddx, *ddy, transpose_x, transpose_y, dx, dy, ddout); } }; @@ -218,7 +225,9 @@ class MatMulV2TripleGradKernel : public framework::OpKernel { auto& dev_ctx = context.device_context(); // call new kernel pten::MatmulTripleGradKernel( - dev_ctx, *x, *y, *dout, *ddx, *ddy, *d_dx, *d_dy, *d_ddout, transpose_x, + static_cast::TYPE&>(dev_ctx), + *x, *y, *dout, *ddx, *ddy, *d_dx, *d_dy, *d_ddout, transpose_x, transpose_y, out_d_x, out_d_y, out_d_dout, out_d_ddx, out_d_ddy); } }; diff --git a/paddle/fluid/operators/memcpy_d2h_op.h b/paddle/fluid/operators/memcpy_d2h_op.h index fb5610dda70..e1b81c0c592 100644 --- a/paddle/fluid/operators/memcpy_d2h_op.h +++ b/paddle/fluid/operators/memcpy_d2h_op.h @@ -16,12 +16,6 @@ limitations under the License. */ #include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/platform/device_context.h" -namespace paddle { -namespace platform { -class DeviceContext; -} // namespace platform -} // namespace paddle - namespace pten { class DenseTensor; } // namespace pten diff --git a/paddle/fluid/operators/memcpy_h2d_op.h b/paddle/fluid/operators/memcpy_h2d_op.h index e84dedd9112..7f487001040 100644 --- a/paddle/fluid/operators/memcpy_h2d_op.h +++ b/paddle/fluid/operators/memcpy_h2d_op.h @@ -17,12 +17,6 @@ limitations under the License. */ #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/stream/stream.h" -namespace paddle { -namespace platform { -class DeviceContext; -} // namespace platform -} // namespace paddle - namespace pten { class DenseTensor; } // namespace pten diff --git a/paddle/fluid/operators/memcpy_op.h b/paddle/fluid/operators/memcpy_op.h index d2a081ac3c2..ac4a0d1ab11 100644 --- a/paddle/fluid/operators/memcpy_op.h +++ b/paddle/fluid/operators/memcpy_op.h @@ -19,12 +19,6 @@ limitations under the License. */ #include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/platform/device_context.h" -namespace paddle { -namespace platform { -class DeviceContext; -} // namespace platform -} // namespace paddle - namespace pten { class DenseTensor; } // namespace pten diff --git a/paddle/fluid/operators/pscore/heter_listen_and_serv_op.h b/paddle/fluid/operators/pscore/heter_listen_and_serv_op.h index f81b45ec05f..77c755581f9 100644 --- a/paddle/fluid/operators/pscore/heter_listen_and_serv_op.h +++ b/paddle/fluid/operators/pscore/heter_listen_and_serv_op.h @@ -45,9 +45,6 @@ class Executor; class ProgramDesc; class Scope; } // namespace framework -namespace platform { -class DeviceContext; -} // namespace platform } // namespace paddle namespace paddle { diff --git a/paddle/fluid/operators/recurrent_op.h b/paddle/fluid/operators/recurrent_op.h index e3f512d45c0..1ca66527e1b 100644 --- a/paddle/fluid/operators/recurrent_op.h +++ b/paddle/fluid/operators/recurrent_op.h @@ -22,12 +22,6 @@ limitations under the License. */ #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/op_registry.h" -namespace paddle { -namespace platform { -class DeviceContext; -} // namespace platform -} // namespace paddle - namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index 25603b07c7a..e2002856a4d 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -258,8 +258,11 @@ class ReduceKernel : public framework::OpKernel { std::vector tmp_dims(dims.begin(), dims.end()); // call new kernel - pten::Reduce( - dev_ctx, *pt_x.get(), reduce_all, tmp_dims, keep_dim, + pten::Reduce::TYPE, + T, Functor>( + static_cast::TYPE&>(dev_ctx), + *pt_x.get(), reduce_all, tmp_dims, keep_dim, pten::TransToPtenDataType(cast_out_dtype), pt_out.get()); } }; diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 4c9be6d0ccb..dc82d7c6c1e 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -19,6 +19,7 @@ limitations under the License. */ // only can include the headers in paddle/pten/api dirs #include "paddle/pten/api/lib/utils/tensor_utils.h" +#include "paddle/pten/backends/cpu/cpu_context.h" #include "paddle/pten/common/scalar_array.h" #include "paddle/pten/kernels/reshape_grad_kernel.h" #include "paddle/pten/kernels/reshape_kernel.h" @@ -435,7 +436,8 @@ class ReshapeKernel { } if (platform::is_cpu_place(ctx.GetPlace())) { auto &dev_ctx = ctx.device_context(); - pten::ReshapeKernel(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out); + pten::ReshapeKernel(static_cast(dev_ctx), + *pt_x.get(), pt_scalar_shape, pt_out); } #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (platform::is_gpu_place(ctx.GetPlace())) { @@ -471,7 +473,8 @@ class ReshapeGradKernel { if (platform::is_cpu_place(ctx.GetPlace())) { auto &dev_ctx = ctx.device_context(); - pten::ReshapeGradKernel(dev_ctx, *pt_d_out.get(), pt_d_x.get()); + pten::ReshapeGradKernel(static_cast(dev_ctx), + *pt_d_out.get(), pt_d_x.get()); } #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (platform::is_gpu_place(ctx.GetPlace())) { @@ -500,7 +503,9 @@ class ReshapeDoubleGradKernel { if (platform::is_cpu_place(ctx.GetPlace())) { auto &dev_ctx = ctx.device_context(); - pten::ReshapeDoubleGradKernel(dev_ctx, *pt_dd_x.get(), pt_dd_out.get()); + pten::ReshapeDoubleGradKernel( + static_cast(dev_ctx), *pt_dd_x.get(), + pt_dd_out.get()); } #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (platform::is_gpu_place(ctx.GetPlace())) { diff --git a/paddle/fluid/operators/scale_op.h b/paddle/fluid/operators/scale_op.h index 4faa23b6c16..a04837b6949 100644 --- a/paddle/fluid/operators/scale_op.h +++ b/paddle/fluid/operators/scale_op.h @@ -67,7 +67,10 @@ class ScaleKernel : public framework::OpKernel { auto& dev_ctx = ctx.device_context(); // call new kernel - pten::ScaleKernel(dev_ctx, *in, scale, bias, bias_after_scale, out); + pten::ScaleKernel( + static_cast::TYPE&>(dev_ctx), + *in, scale, bias, bias_after_scale, out); } }; diff --git a/paddle/fluid/operators/seed_op.cu b/paddle/fluid/operators/seed_op.cu index 2154b08ae86..5a8d1c067c3 100644 --- a/paddle/fluid/operators/seed_op.cu +++ b/paddle/fluid/operators/seed_op.cu @@ -30,7 +30,7 @@ class GPUSeedKernel : public framework::OpKernel { if (cpu_place) { platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(context.GetPlace()); + auto &dev_ctx = *pool.Get(platform::CPUPlace()); out->mutable_data(platform::CPUPlace()); math::SetConstant functor; functor(reinterpret_cast(dev_ctx), diff --git a/paddle/fluid/operators/sign_op.h b/paddle/fluid/operators/sign_op.h index 737560d314a..41bcf9e8ae1 100644 --- a/paddle/fluid/operators/sign_op.h +++ b/paddle/fluid/operators/sign_op.h @@ -35,7 +35,11 @@ class SignKernel : public framework::OpKernel { out->mutable_data(x->place()); // call new kernel - pten::SignKernel(dev_ctx, *x, out); + pten::SignKernel::TYPE>( + static_cast::TYPE&>(dev_ctx), + *x, out); } }; diff --git a/paddle/fluid/operators/transfer_layout_op.h b/paddle/fluid/operators/transfer_layout_op.h index 74d086015ee..cd3f7e70678 100644 --- a/paddle/fluid/operators/transfer_layout_op.h +++ b/paddle/fluid/operators/transfer_layout_op.h @@ -21,12 +21,6 @@ #include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/platform/device_context.h" -namespace paddle { -namespace platform { -class DeviceContext; -} // namespace platform -} // namespace paddle - namespace pten { class DenseTensor; } // namespace pten diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 9e0a0cb5f8d..21531a3efd6 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -123,7 +123,7 @@ cc_library(cudnn_workspace_helper SRCS cudnn_workspace_helper.cc DEPS boost) # avoiding cycle dependencies cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc xxhash ${STREAM_CALLBACK_DEPS} place pten_place eigen3 stringpiece cpu_helper cpu_info framework_proto ${IPU_CTX_DEPS} ${GPU_CTX_DEPS} ${NPU_CTX_DEPS} ${MKLDNN_CTX_DEPS} - ${dgc_deps} dlpack cudnn_workspace_helper ${XPU_CTX_DEPS} ${MLU_CTX_DEPS}) + ${dgc_deps} dlpack cudnn_workspace_helper ${XPU_CTX_DEPS} ${MLU_CTX_DEPS} cpu_context) cc_library(collective_helper SRCS collective_helper.cc gen_comm_id_helper.cc DEPS framework_proto device_context enforce) if(WITH_ASCEND_CL) diff --git a/paddle/fluid/platform/device/mlu/device_context.h b/paddle/fluid/platform/device/mlu/device_context.h index 67c547dc69a..2692f3a248a 100644 --- a/paddle/fluid/platform/device/mlu/device_context.h +++ b/paddle/fluid/platform/device/mlu/device_context.h @@ -21,8 +21,6 @@ struct DefaultDevice; struct GpuDevice; } // namespace Eigen -// class DeviceContext; - namespace paddle { namespace platform { diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 7347853fff1..27b900198bc 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -132,7 +132,7 @@ platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) { return it->second.get().get(); } -template +template inline void EmplaceDeviceContext( std::map>>* map_ptr, @@ -158,19 +158,14 @@ DeviceContextPool::DeviceContextPool( } for (auto& p : set) { if (platform::is_cpu_place(p)) { - platform::CPUPlace place; #ifdef PADDLE_WITH_MKLDNN - EmplaceDeviceContext(&device_contexts_, - place); + EmplaceDeviceContext(&device_contexts_, p); #else - EmplaceDeviceContext(&device_contexts_, - place); + EmplaceDeviceContext(&device_contexts_, p); #endif } else if (platform::is_gpu_place(p)) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - platform::CUDAPlace place(p.GetDeviceId()); - EmplaceDeviceContext(&device_contexts_, - place); + EmplaceDeviceContext(&device_contexts_, p); #else PADDLE_THROW( platform::errors::Unimplemented("CUDAPlace is not supported. Please " @@ -178,9 +173,7 @@ DeviceContextPool::DeviceContextPool( #endif } else if (platform::is_cuda_pinned_place(p)) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - platform::CUDAPinnedPlace place; - EmplaceDeviceContext( - &device_contexts_, place); + EmplaceDeviceContext(&device_contexts_, p); #else PADDLE_THROW(platform::errors::Unimplemented( "CUDAPlace is not supported. Please re-compile with WITH_GPU " @@ -188,9 +181,7 @@ DeviceContextPool::DeviceContextPool( #endif } else if (platform::is_xpu_place(p)) { #ifdef PADDLE_WITH_XPU - platform::XPUPlace place(p.GetDeviceId()); - EmplaceDeviceContext(&device_contexts_, - place); + EmplaceDeviceContext(&device_contexts_, p); #else PADDLE_THROW( platform::errors::Unimplemented("XPUPlace is not supported. Please " @@ -198,9 +189,7 @@ DeviceContextPool::DeviceContextPool( #endif } else if (platform::is_mlu_place(p)) { #ifdef PADDLE_WITH_MLU - platform::MLUPlace place(p.GetDeviceId()); - EmplaceDeviceContext(&device_contexts_, - place); + EmplaceDeviceContext(&device_contexts_, p); #else PADDLE_THROW( platform::errors::Unimplemented("MLUPlace is not supported. Please " @@ -208,9 +197,7 @@ DeviceContextPool::DeviceContextPool( #endif } else if (platform::is_ipu_place(p)) { #ifdef PADDLE_WITH_IPU - platform::IPUPlace place(p.GetDeviceId()); - EmplaceDeviceContext(&device_contexts_, - place); + EmplaceDeviceContext(&device_contexts_, p); #else PADDLE_THROW( platform::errors::Unimplemented("IPUPlace is not supported. Please " @@ -218,9 +205,7 @@ DeviceContextPool::DeviceContextPool( #endif } else if (platform::is_npu_place(p)) { #ifdef PADDLE_WITH_ASCEND_CL - platform::NPUPlace place(p.GetDeviceId()); - EmplaceDeviceContext(&device_contexts_, - place); + EmplaceDeviceContext(&device_contexts_, p); #else PADDLE_THROW(platform::errors::Unimplemented( "NPUPlace is not supported. Please " @@ -228,9 +213,7 @@ DeviceContextPool::DeviceContextPool( #endif } else if (platform::is_npu_pinned_place(p)) { #ifdef PADDLE_WITH_ASCEND_CL - platform::NPUPinnedPlace place; - EmplaceDeviceContext( - &device_contexts_, place); + EmplaceDeviceContext(&device_contexts_, p); #else PADDLE_THROW(platform::errors::Unimplemented( "NPUPinnedPlace is not supported. Please re-compile with " @@ -241,19 +224,9 @@ DeviceContextPool::DeviceContextPool( } } -CPUDeviceContext::CPUDeviceContext() { - eigen_device_.reset(new Eigen::DefaultDevice()); -} - -CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) { - eigen_device_.reset(new Eigen::DefaultDevice()); -} - -Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const { - return eigen_device_.get(); -} +CPUDeviceContext::CPUDeviceContext() : pten::CPUContext() {} -Place CPUDeviceContext::GetPlace() const { return place_; } +CPUDeviceContext::CPUDeviceContext(CPUPlace place) : pten::CPUContext() {} #ifdef PADDLE_WITH_IPU IPUDeviceContext::IPUDeviceContext(IPUPlace place) : place_(place) { diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index e1fcc3ae900..78c09dca5b4 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -18,6 +18,9 @@ limitations under the License. */ #include #include +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/core/device_context.h" + #include "paddle/fluid/memory/malloc.h" #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/device/gpu/gpu_helper.h" @@ -117,26 +120,15 @@ constexpr DeviceType kNPU = DeviceType::NPU; constexpr DeviceType kIPU = DeviceType::IPU; constexpr DeviceType kMLU = DeviceType::MLU; -class DeviceContext { - public: - virtual ~DeviceContext() PADDLE_MAY_THROW {} - virtual Place GetPlace() const = 0; - - virtual void Wait() const {} -}; +using DeviceContext = pten::DeviceContext; -class CPUDeviceContext : public DeviceContext { +// using CPUDeviceContext = pten::CPUContext; +// TODO(wilber): The place constructor is used in many places, it is more +// difficult to use CPUDeviceContext = pten::CPUContext directly. +class CPUDeviceContext : public pten::CPUContext { public: CPUDeviceContext(); explicit CPUDeviceContext(CPUPlace place); - - Eigen::DefaultDevice* eigen_device() const; - - Place GetPlace() const override; - - private: - CPUPlace place_; - std::unique_ptr eigen_device_; }; template diff --git a/paddle/fluid/platform/for_range.h b/paddle/fluid/platform/for_range.h index 5518dabbf92..1d4be3801dd 100644 --- a/paddle/fluid/platform/for_range.h +++ b/paddle/fluid/platform/for_range.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/pten/backends/cpu/cpu_context.h" namespace paddle { namespace platform { @@ -27,6 +28,7 @@ struct ForRange { void operator()(Function func) const; }; +// NOTE: After the pten kernel is migrated, it needs to be deleted. template <> struct ForRange { ForRange(const CPUDeviceContext& dev_ctx, size_t limit) : limit_(limit) {} @@ -41,6 +43,20 @@ struct ForRange { size_t limit_; }; +template <> +struct ForRange { + ForRange(const pten::CPUContext& dev_ctx, size_t limit) : limit_(limit) {} + + template + void operator()(Function func) const { + for (size_t i = 0; i < limit_; ++i) { + func(i); + } + } + + size_t limit_; +}; + #if defined(__NVCC__) || defined(__HIPCC__) template __global__ static void ForRangeElemwiseOpGridIsOne(Function func) { diff --git a/paddle/fluid/platform/transform.h b/paddle/fluid/platform/transform.h index 81c9909df77..cc9919d8366 100644 --- a/paddle/fluid/platform/transform.h +++ b/paddle/fluid/platform/transform.h @@ -59,6 +59,7 @@ struct Transform { BinaryOperation op); }; +// NOTE: After the pten kernel is migrated, it needs to be deleted. template <> struct Transform { template @@ -76,6 +77,23 @@ struct Transform { } }; +template <> +struct Transform { + template + void operator()(const pten::CPUContext& context, InputIter first, + InputIter last, OutputIter result, UnaryOperation op) { + std::transform(first, last, result, op); + } + + template + void operator()(const pten::CPUContext& context, InputIter1 first1, + InputIter1 last1, InputIter2 first2, OutputIter result, + BinaryOperation op) { + std::transform(first1, last1, first2, result, op); + } +}; + #if defined(__NVCC__) || defined(__HIPCC__) template <> struct Transform { diff --git a/paddle/pten/backends/CMakeLists.txt b/paddle/pten/backends/CMakeLists.txt index af49ad61a48..e45adefe652 100644 --- a/paddle/pten/backends/CMakeLists.txt +++ b/paddle/pten/backends/CMakeLists.txt @@ -1 +1,2 @@ +add_subdirectory(cpu) cc_library(pten_context SRCS all_context.cc DEPS device_context) diff --git a/paddle/pten/backends/all_context.h b/paddle/pten/backends/all_context.h index 8cc07d216c0..52a216e4f7d 100644 --- a/paddle/pten/backends/all_context.h +++ b/paddle/pten/backends/all_context.h @@ -24,7 +24,9 @@ limitations under the License. */ #include "paddle/pten/backends/gpu/gpu_context.h" #include "paddle/pten/backends/xpu/xpu_context.h" +// TODO(wilber): DeviceContextPool nees include fluid file. +#include "paddle/fluid/platform/device_context.h" + namespace pten { -using DeviceContext = paddle::platform::DeviceContext; using DeviceContextPool = paddle::platform::DeviceContextPool; } // namespace pten diff --git a/paddle/pten/backends/cpu/CMakeLists.txt b/paddle/pten/backends/cpu/CMakeLists.txt new file mode 100644 index 00000000000..62eff2dedc9 --- /dev/null +++ b/paddle/pten/backends/cpu/CMakeLists.txt @@ -0,0 +1,6 @@ +if(WITH_MKLDNN) + # TODO(wilber): support mkldnn context. + cc_library(cpu_context SRCS cpu_context.cc DEPS pten_device_context mkldnn) +else() + cc_library(cpu_context SRCS cpu_context.cc DEPS pten_device_context) +endif() diff --git a/paddle/pten/backends/cpu/cpu_context.cc b/paddle/pten/backends/cpu/cpu_context.cc new file mode 100644 index 00000000000..e749dfb9bd7 --- /dev/null +++ b/paddle/pten/backends/cpu/cpu_context.cc @@ -0,0 +1,93 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/backends/cpu/cpu_context.h" + +#include "paddle/pten/api/ext/exception.h" + +// NOTE: The paddle framework should add WITH_EIGEN option to support compile +// without eigen. +#include "paddle/pten/core/device_context.h" +#include "unsupported/Eigen/CXX11/Tensor" + +namespace pten { + +struct CPUContext::CPUImpl { + Eigen::DefaultDevice* device_{nullptr}; + CPUContextResource res_; + CPUPlace place_; + + CPUImpl() { device_ = new Eigen::DefaultDevice(); } + + // Users need to manage external resources. + explicit CPUImpl(const CPUContextResource& ctx_res) : res_(ctx_res) { + device_ = res_.device; + } + + ~CPUImpl() { + if (res_.device == nullptr) { + delete device_; + device_ = nullptr; + } + } + + Eigen::DefaultDevice* GetEigenDevice() const { + PD_CHECK(device_ != nullptr, "the eigen_device is nullptr."); + return device_; + } + + void SetEigenDevice(Eigen::DefaultDevice* device) { + if (device == nullptr) { + return; + } + res_.device = device; + device_ = device; + } + + Place GetPlace() const { return place_; } +}; + +CPUContext::CPUContext() : DeviceContext(), cpu_impl_(nullptr) { + cpu_impl_ = std::make_unique(); +} + +CPUContext::CPUContext(const CPUContext& other) + : DeviceContext(), cpu_impl_(nullptr) { + cpu_impl_ = std::make_unique(); + cpu_impl_->SetEigenDevice(other.eigen_device()); +} + +CPUContext::CPUContext(CPUContext&& other) + : DeviceContext(), cpu_impl_(nullptr) { + cpu_impl_ = std::move(other.cpu_impl_); +} + +CPUContext::~CPUContext() = default; + +CPUContext::CPUContext(const CPUContextResource& ctx_res) + : DeviceContext(), cpu_impl_(nullptr) { + cpu_impl_ = std::make_unique(ctx_res); +} + +Eigen::DefaultDevice* CPUContext::eigen_device() const { + return cpu_impl_->GetEigenDevice(); +} + +void CPUContext::SetEigenDevice(Eigen::DefaultDevice* device) { + cpu_impl_->SetEigenDevice(device); +} + +Place CPUContext::GetPlace() const { return cpu_impl_->GetPlace(); } + +} // namespace pten diff --git a/paddle/pten/backends/cpu/cpu_context.h b/paddle/pten/backends/cpu/cpu_context.h index b161e98d636..059588dc712 100644 --- a/paddle/pten/backends/cpu/cpu_context.h +++ b/paddle/pten/backends/cpu/cpu_context.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,9 +14,47 @@ limitations under the License. */ #pragma once -// See Note [ Why still include the fluid headers? ] -#include "paddle/fluid/platform/device_context.h" +#include + +#include "paddle/pten/backends/cpu/forwards.h" +#include "paddle/pten/core/device_context.h" + +// TODO(wilber): Do we need to use place in pten kernel? +#include "paddle/pten/common/place.h" namespace pten { -using CPUContext = paddle::platform::CPUDeviceContext; + +struct CPUContextResource { + Eigen::DefaultDevice* device{nullptr}; +}; + +class CPUContext : public DeviceContext { + public: + // NOTE: DeviceContext hold resources. Used in training scenarios. + CPUContext(); + + // NOTE: Share the same underlying resources, please ensure that resources are + // not released. + CPUContext(const CPUContext&); + + CPUContext(CPUContext&&); + + ~CPUContext(); + + Eigen::DefaultDevice* eigen_device() const; + + // TODO(wilber): Whether the interface should be preserved. + Place GetPlace() const override; + + public: + // NOTE: External users manage resources. Used in inference scenarios. + explicit CPUContext(const CPUContextResource& ctx_res); + + void SetEigenDevice(Eigen::DefaultDevice* device); + + private: + struct CPUImpl; + std::unique_ptr cpu_impl_; +}; + } // namespace pten diff --git a/paddle/pten/backends/cpu/forwards.h b/paddle/pten/backends/cpu/forwards.h new file mode 100644 index 00000000000..202a414893e --- /dev/null +++ b/paddle/pten/backends/cpu/forwards.h @@ -0,0 +1,21 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +// Forward-declares. +#pragma once + +// Forward declaration of Eigen DefaultDevice types. +namespace Eigen { +struct DefaultDevice; +} // namespace Eigen diff --git a/paddle/pten/core/CMakeLists.txt b/paddle/pten/core/CMakeLists.txt index b1e35b240cd..facc9ac0056 100644 --- a/paddle/pten/core/CMakeLists.txt +++ b/paddle/pten/core/CMakeLists.txt @@ -13,8 +13,10 @@ cc_library(arg_map_context SRCS arg_map_context.cc DEPS enforce) cc_library(tensor_base SRCS tensor_base.cc allocator.cc storage.cc DEPS enforce) cc_library(tensor_meta SRCS tensor_meta.cc DEPS enforce mixed_vector) cc_library(dense_tensor SRCS dense_tensor.cc DEPS convert_utils tensor_meta tensor_base) +cc_library(pten_device_context SRCS device_context.cc DEPS tensor_base ) # Will remove once we implemented MKLDNN_Tensor if(WITH_MKLDNN) add_dependencies(dense_tensor mkldnn) + add_dependencies(tensor_base mkldnn) endif() diff --git a/paddle/pten/core/device_context.cc b/paddle/pten/core/device_context.cc new file mode 100644 index 00000000000..7b2c4a2cf17 --- /dev/null +++ b/paddle/pten/core/device_context.cc @@ -0,0 +1,56 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/core/device_context.h" + +namespace pten { + +struct DeviceContext::Impl { + Allocator* allocator_{nullptr}; + + Impl() = default; + ~Impl() = default; + + void SetAllocator(Allocator* allocator) { allocator_ = allocator; } + + const Allocator& GetAllocator() const { return *allocator_; } + + // TODO(Wilber): Add impl. It seems that tensorbase not have interface to + // communicate with allocator. + void Alloc(TensorBase* tensor) {} +}; + +DeviceContext::DeviceContext() { impl_ = std::make_unique(); } + +DeviceContext::DeviceContext(const DeviceContext& other) { + impl_->SetAllocator(const_cast(&other.GetAllocator())); +} + +DeviceContext::DeviceContext(DeviceContext&& other) { + impl_ = std::move(other.impl_); +} + +DeviceContext::~DeviceContext() = default; + +void DeviceContext::SetAllocator(Allocator* allocator) { + impl_->SetAllocator(allocator); +} + +const Allocator& DeviceContext::GetAllocator() const { + return impl_->GetAllocator(); +} + +void DeviceContext::Alloc(TensorBase* tensor) { impl_->Alloc(tensor); } + +} // namespace pten diff --git a/paddle/pten/core/device_context.h b/paddle/pten/core/device_context.h new file mode 100644 index 00000000000..bb851d954f2 --- /dev/null +++ b/paddle/pten/core/device_context.h @@ -0,0 +1,86 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +// TODO(wilber): Do we need to use place in pten kernel? +#include "paddle/pten/common/place.h" + +#include "paddle/pten/core/candidate/allocator.h" + +namespace pten { +class TensorBase; + +/** + * DeviceContext provides device-related interfaces. + * + * All kernels must access the interfaces provided by the backend through + * DeviceContext. + */ +class DeviceContext { + public: + /** + * @brief Default construct. + */ + DeviceContext(); + + /** + * @brief Copy construct. + */ + DeviceContext(const DeviceContext&); + + /** + * @brief Move construct. + */ + DeviceContext(DeviceContext&&); + + /** + * @brief Default destruct. + */ + virtual ~DeviceContext(); + + /** + * @brief Set the deveice-releated Allocator object. + * + * @param allocator + */ + void SetAllocator(Allocator*); + + /** + * @brief Get the const Allocator object. + * + * @return Allocator + */ + const Allocator& GetAllocator() const; + + /** + * @brief Allocate memory for tensor. + */ + void Alloc(pten::TensorBase*); + + // TODO(wilber): Just for the convenience of migrating the code, it will be + // modified or removed later. + virtual Place GetPlace() const = 0; + // TODO(wilber): The fluid framework uses wait() in many places, how to delete + // this API interface. + virtual void Wait() const {} + + private: + struct Impl; + std::unique_ptr impl_; +}; + +} // namespace pten diff --git a/paddle/pten/tests/core/CMakeLists.txt b/paddle/pten/tests/core/CMakeLists.txt index 2d4ee7f6d6a..117d6a29252 100644 --- a/paddle/pten/tests/core/CMakeLists.txt +++ b/paddle/pten/tests/core/CMakeLists.txt @@ -3,3 +3,4 @@ cc_test(test_intrusive_ptr SRCS test_intrusive_ptr.cc) cc_test(test_type_info SRCS test_type_info.cc) cc_test(test_convert_utils SRCS test_convert_utils.cc DEPS convert_utils) cc_test(test_kernel_factory SRCS test_kernel_factory.cc DEPS kernel_factory scale_kernel) +cc_test(test_pten_device_context SRCS test_device_context.cc DEPS pten_context cpu_context) diff --git a/paddle/pten/tests/core/test_device_context.cc b/paddle/pten/tests/core/test_device_context.cc new file mode 100644 index 00000000000..a44d0d32156 --- /dev/null +++ b/paddle/pten/tests/core/test_device_context.cc @@ -0,0 +1,68 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "gtest/gtest.h" + +// TODO(wilber): will remove after the cpu, gpu context megre. +#include "paddle/pten/backends/cpu/cpu_context.h" +// #include "paddle/pten/backends/all_context.h" + +// NOTE: The paddle framework should add WITH_EIGEN option to support compile +// without eigen. +#include "unsupported/Eigen/CXX11/Tensor" + +namespace pten { +namespace tests { + +TEST(DeviceContext, cpu_context) { + std::cout << "test training scenarios" << std::endl; + { + pten::CPUContext ctx; + CHECK(ctx.eigen_device() != nullptr); + } + + std::cout << "test inference scenarios" << std::endl; + Eigen::DefaultDevice* device = new Eigen::DefaultDevice(); + { + pten::CPUContextResource ctx_res{device}; + pten::CPUContext ctx(ctx_res); + CHECK(ctx.eigen_device() != nullptr); + } + { + pten::CPUContextResource ctx_res{nullptr}; + pten::CPUContext ctx(ctx_res); + ctx.SetEigenDevice(device); + CHECK(ctx.eigen_device() != nullptr); + } + delete device; + + std::cout << "test copy constructor" << std::endl; + { + pten::CPUContext ctx1; + pten::CPUContext ctx2(ctx1); + CHECK_EQ(ctx1.eigen_device(), ctx2.eigen_device()); + } + + std::cout << "test move constructor" << std::endl; + { + pten::CPUContext ctx1 = pten::CPUContext(); + auto* eigen_device1 = ctx1.eigen_device(); + pten::CPUContext ctx2(std::move(ctx1)); + auto* eigen_device2 = ctx2.eigen_device(); + CHECK_EQ(eigen_device1, eigen_device2); + } +} + +} // namespace tests +} // namespace pten diff --git a/paddle/pten/tests/kernels/test_cast_dev_api.cc b/paddle/pten/tests/kernels/test_cast_dev_api.cc index 90624adeb34..80328d0b243 100644 --- a/paddle/pten/tests/kernels/test_cast_dev_api.cc +++ b/paddle/pten/tests/kernels/test_cast_dev_api.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include #include +#include "paddle/pten/backends/cpu/cpu_context.h" #include "paddle/pten/kernels/cast_kernel.h" #include "paddle/pten/api/lib/utils/allocator.h" @@ -44,16 +45,11 @@ TEST(DEV_API, cast) { dense_x_data[i] = i * 1.0; sum += i * 1.0; } - paddle::platform::DeviceContextPool& pool = - paddle::platform::DeviceContextPool::Instance(); - auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); + pten::CPUContext dev_ctx; pten::DataType out_dtype = pten::DataType::FLOAT64; // 2. test API - auto out = pten::Cast( - *(static_cast(dev_ctx)), - dense_x, - out_dtype); + auto out = pten::Cast(dev_ctx, dense_x, out_dtype); // 3. check result ASSERT_EQ(out.dims().size(), 2); diff --git a/paddle/pten/tests/kernels/test_conj_dev_api.cc b/paddle/pten/tests/kernels/test_conj_dev_api.cc index 789d55491f3..6f2ea0602b8 100644 --- a/paddle/pten/tests/kernels/test_conj_dev_api.cc +++ b/paddle/pten/tests/kernels/test_conj_dev_api.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include #include +#include "paddle/pten/backends/cpu/cpu_context.h" #include "paddle/pten/kernels/complex_kernel.h" #include "paddle/pten/api/lib/utils/allocator.h" @@ -41,13 +42,10 @@ TEST(DEV_API, conj) { dense_x_data[i] = paddle::complex64(i * 1.0, i * 1.0); } - paddle::platform::DeviceContextPool& pool = - paddle::platform::DeviceContextPool::Instance(); - auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); + pten::CPUContext dev_ctx; // 2. test API - auto out = pten::Conj( - *(static_cast(dev_ctx)), dense_x); + auto out = pten::Conj(dev_ctx, dense_x); // 3. check result ASSERT_EQ(out.dims().size(), 2); diff --git a/paddle/pten/tests/kernels/test_copy_dev_api.cc b/paddle/pten/tests/kernels/test_copy_dev_api.cc index c4d8c37eb9e..d690b29d71f 100644 --- a/paddle/pten/tests/kernels/test_copy_dev_api.cc +++ b/paddle/pten/tests/kernels/test_copy_dev_api.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include #include +#include "paddle/pten/backends/cpu/cpu_context.h" #include "paddle/pten/core/kernel_registry.h" #include "paddle/pten/kernels/copy_kernel.h" @@ -54,9 +55,8 @@ TEST(DEV_API, copy) { const auto& a = paddle::platform::CPUPlace(); std::cout << typeid(a).name() << std::endl; // 2. test API - auto& pool = paddle::platform::DeviceContextPool::Instance(); - auto* dev_ctx = pool.GetByPlace(paddle::platform::CPUPlace()); - pten::Copy(*dev_ctx, *(dense_src.get()), false, dense_dst.get()); + pten::CPUContext dev_ctx; + pten::Copy(dev_ctx, *(dense_src.get()), false, dense_dst.get()); // 3. check result for (int64_t i = 0; i < dense_src->numel(); i++) { diff --git a/paddle/pten/tests/kernels/test_creation_dev_api.cc b/paddle/pten/tests/kernels/test_creation_dev_api.cc index 169a77cf343..b1c23d4a768 100644 --- a/paddle/pten/tests/kernels/test_creation_dev_api.cc +++ b/paddle/pten/tests/kernels/test_creation_dev_api.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include #include +#include "paddle/pten/backends/cpu/cpu_context.h" #include "paddle/pten/kernels/empty_kernel.h" #include "paddle/pten/kernels/full_kernel.h" @@ -30,15 +31,10 @@ using DDim = paddle::framework::DDim; TEST(DEV_API, empty) { // 1. create input - paddle::platform::DeviceContextPool& pool = - paddle::platform::DeviceContextPool::Instance(); - auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); + pten::CPUContext dev_ctx; // 2. test API - auto out = pten::Empty( - *(static_cast(dev_ctx)), - {3, 2}, - pten::DataType::INT32); + auto out = pten::Empty(dev_ctx, {3, 2}, pten::DataType::INT32); // 3. check result ASSERT_EQ(out.dims().size(), 2); @@ -59,13 +55,9 @@ TEST(DEV_API, empty_like) { auto* dense_x_data = dense_x.mutable_data(); dense_x_data[0] = 0; - paddle::platform::DeviceContextPool& pool = - paddle::platform::DeviceContextPool::Instance(); - auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); - // 2. test API - auto out = pten::EmptyLike( - *(static_cast(dev_ctx)), dense_x); + pten::CPUContext dev_ctx; + auto out = pten::EmptyLike(dev_ctx, dense_x); // 3. check result ASSERT_EQ(out.dims().size(), 2); @@ -79,16 +71,9 @@ TEST(DEV_API, full) { // 1. create input float val = 1.0; - paddle::platform::DeviceContextPool& pool = - paddle::platform::DeviceContextPool::Instance(); - auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); - // 2. test API - auto out = pten::Full( - *(static_cast(dev_ctx)), - {3, 2}, - val, - pten::DataType::FLOAT32); + pten::CPUContext dev_ctx; + auto out = pten::Full(dev_ctx, {3, 2}, val, pten::DataType::FLOAT32); // 3. check result ASSERT_EQ(out.dims().size(), 2); @@ -115,15 +100,10 @@ TEST(DEV_API, full_like) { dense_x_data[0] = 0; float val = 1.0; - paddle::platform::DeviceContextPool& pool = - paddle::platform::DeviceContextPool::Instance(); - auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); + pten::CPUContext dev_ctx; // 2. test API - auto out = pten::FullLike( - *(static_cast(dev_ctx)), - dense_x, - val); + auto out = pten::FullLike(dev_ctx, dense_x, val); // 3. check result ASSERT_EQ(out.dims().size(), 2); diff --git a/paddle/pten/tests/kernels/test_dot_dev_api.cc b/paddle/pten/tests/kernels/test_dot_dev_api.cc index a5773b8aa96..4213240f57b 100644 --- a/paddle/pten/tests/kernels/test_dot_dev_api.cc +++ b/paddle/pten/tests/kernels/test_dot_dev_api.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include #include +#include "paddle/pten/backends/cpu/cpu_context.h" #include "paddle/pten/kernels/dot_kernel.h" #include "paddle/pten/api/lib/utils/allocator.h" @@ -52,15 +53,9 @@ TEST(DEV_API, dot) { } } - paddle::platform::DeviceContextPool& pool = - paddle::platform::DeviceContextPool::Instance(); - auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); - // 2. test API - auto out = pten::Dot( - *(static_cast(dev_ctx)), - dense_x, - dense_y); + pten::CPUContext dev_ctx; + auto out = pten::Dot(dev_ctx, dense_x, dense_y); // 3. check result ASSERT_EQ(out.dims().size(), 2); diff --git a/paddle/pten/tests/kernels/test_elementwise_dev_api.cc b/paddle/pten/tests/kernels/test_elementwise_dev_api.cc index 40998a8d57c..23583a84356 100644 --- a/paddle/pten/tests/kernels/test_elementwise_dev_api.cc +++ b/paddle/pten/tests/kernels/test_elementwise_dev_api.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include #include +#include "paddle/pten/backends/cpu/cpu_context.h" #include "paddle/pten/kernels/math_kernel.h" #include "paddle/pten/api/lib/utils/allocator.h" @@ -54,16 +55,10 @@ TEST(DEV_API, add) { dense_y_data[i] = i * 2.0; } int axis = 1; - paddle::platform::DeviceContextPool& pool = - paddle::platform::DeviceContextPool::Instance(); - auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); // 2. test API - auto dense_out = pten::Add( - *(static_cast(dev_ctx)), - dense_x, - dense_y, - axis); + pten::CPUContext dev_ctx; + auto dense_out = pten::Add(dev_ctx, dense_x, dense_y, axis); // 3. check result ASSERT_EQ(dense_out.dims().size(), 2); @@ -107,16 +102,10 @@ TEST(DEV_API, subtract) { dense_y_data[i] = i * 2.0; } int axis = 1; - paddle::platform::DeviceContextPool& pool = - paddle::platform::DeviceContextPool::Instance(); - auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); // 2. test API - auto dense_out = pten::Subtract( - *(static_cast(dev_ctx)), - dense_x, - dense_y, - axis); + pten::CPUContext dev_ctx; + auto dense_out = pten::Subtract(dev_ctx, dense_x, dense_y, axis); // 3. check result ASSERT_EQ(dense_out.dims().size(), 2); @@ -160,16 +149,10 @@ TEST(DEV_API, divide) { dense_y_data[i] = i * 2.0 + 1; } int axis = 1; - paddle::platform::DeviceContextPool& pool = - paddle::platform::DeviceContextPool::Instance(); - auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); // 2. test API - auto dense_out = pten::Divide( - *(static_cast(dev_ctx)), - dense_x, - dense_y, - axis); + pten::CPUContext dev_ctx; + auto dense_out = pten::Divide(dev_ctx, dense_x, dense_y, axis); // 3. check result ASSERT_EQ(dense_out.dims().size(), 2); @@ -213,16 +196,10 @@ TEST(DEV_API, multiply) { dense_y_data[i] = i * 2.0; } int axis = 1; - paddle::platform::DeviceContextPool& pool = - paddle::platform::DeviceContextPool::Instance(); - auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); // 2. test API - auto dense_out = pten::Multiply( - *(static_cast(dev_ctx)), - dense_x, - dense_y, - axis); + pten::CPUContext dev_ctx; + auto dense_out = pten::Multiply(dev_ctx, dense_x, dense_y, axis); // 3. check result ASSERT_EQ(dense_out.dims().size(), 2); diff --git a/paddle/pten/tests/kernels/test_flatten_dev_api.cc b/paddle/pten/tests/kernels/test_flatten_dev_api.cc index d66ff468fcf..13fc327b669 100644 --- a/paddle/pten/tests/kernels/test_flatten_dev_api.cc +++ b/paddle/pten/tests/kernels/test_flatten_dev_api.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include #include +#include "paddle/pten/backends/cpu/cpu_context.h" #include "paddle/pten/kernels/flatten_kernel.h" #include "paddle/pten/api/lib/utils/allocator.h" @@ -52,16 +53,10 @@ TEST(DEV_API, flatten) { dense_x_data[i] = i; } int start_axis = 1, stop_axis = 2; - paddle::platform::DeviceContextPool& pool = - paddle::platform::DeviceContextPool::Instance(); - auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); + pten::CPUContext dev_ctx; // 2. test API - auto out = pten::Flatten( - *(static_cast(dev_ctx)), - dense_x, - start_axis, - stop_axis); + auto out = pten::Flatten(dev_ctx, dense_x, start_axis, stop_axis); // 3. check result std::vector expect_shape = {3, 4, 3}; diff --git a/paddle/pten/tests/kernels/test_matmul_dev_api.cc b/paddle/pten/tests/kernels/test_matmul_dev_api.cc index 0c1338f1955..118215db505 100644 --- a/paddle/pten/tests/kernels/test_matmul_dev_api.cc +++ b/paddle/pten/tests/kernels/test_matmul_dev_api.cc @@ -50,13 +50,9 @@ TEST(DEV_API, dot) { } std::vector sum(9, 6.0); - paddle::platform::DeviceContextPool& pool = - paddle::platform::DeviceContextPool::Instance(); - auto* ctx = pool.Get(paddle::platform::CPUPlace()); - // 2. test API - auto out = Matmul( - *(static_cast(ctx)), dense_x, dense_y, false, false); + pten::CPUContext dev_ctx; + auto out = Matmul(dev_ctx, dense_x, dense_y, false, false); // 3. check result ASSERT_EQ(out.dims().size(), 2); diff --git a/paddle/pten/tests/kernels/test_mean_dev_api.cc b/paddle/pten/tests/kernels/test_mean_dev_api.cc index 98782fd5dae..a8860540fd0 100644 --- a/paddle/pten/tests/kernels/test_mean_dev_api.cc +++ b/paddle/pten/tests/kernels/test_mean_dev_api.cc @@ -42,17 +42,11 @@ TEST(DEV_API, mean) { dense_x_data[i] = i * 1.0; sum += i * 1.0; } - paddle::platform::DeviceContextPool& pool = - paddle::platform::DeviceContextPool::Instance(); - auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); - std::vector dims = {0, 1}; + // 2. test API - auto out = pten::Mean( - *(static_cast(dev_ctx)), - dense_x, - dims, - false); + pten::CPUContext dev_ctx; + auto out = pten::Mean(dev_ctx, dense_x, dims, false); // 3. check result ASSERT_EQ(out.dims().size(), 1); diff --git a/paddle/pten/tests/kernels/test_reshape_dev_api.cc b/paddle/pten/tests/kernels/test_reshape_dev_api.cc index 02139d02de1..52038593d70 100644 --- a/paddle/pten/tests/kernels/test_reshape_dev_api.cc +++ b/paddle/pten/tests/kernels/test_reshape_dev_api.cc @@ -42,16 +42,11 @@ TEST(DEV_API, reshape) { for (int i = 0; i < dense_x.numel(); i++) { dense_x_data[i] = i; } - paddle::platform::DeviceContextPool& pool = - paddle::platform::DeviceContextPool::Instance(); - auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); std::vector shape{12, 3}; // 2. test API - auto out = pten::Reshape( - *(static_cast(dev_ctx)), - dense_x, - shape); + pten::CPUContext dev_ctx; + auto out = pten::Reshape(dev_ctx, dense_x, shape); // 3. check result std::vector expect_shape = {12, 3}; ASSERT_EQ(out.dims()[0], expect_shape[0]); diff --git a/paddle/pten/tests/kernels/test_scale_dev_api.cc b/paddle/pten/tests/kernels/test_scale_dev_api.cc index 02f324deb4c..1c0be6c06aa 100644 --- a/paddle/pten/tests/kernels/test_scale_dev_api.cc +++ b/paddle/pten/tests/kernels/test_scale_dev_api.cc @@ -44,17 +44,10 @@ TEST(DEV_API, scale) { float bias = 1; bool bias_after_scale = true; - paddle::platform::DeviceContextPool& pool = - paddle::platform::DeviceContextPool::Instance(); - auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); - // 2. test API - auto out = pten::Scale( - *(static_cast(dev_ctx)), - dense_x, - scale, - bias, - bias_after_scale); + pten::CPUContext dev_ctx; + auto out = + pten::Scale(dev_ctx, dense_x, scale, bias, bias_after_scale); // 3. check result ASSERT_EQ(out.dims().size(), 2); @@ -88,17 +81,10 @@ TEST(DEV_API, scale_host) { float bias = 1; bool bias_after_scale = true; - paddle::platform::DeviceContextPool& pool = - paddle::platform::DeviceContextPool::Instance(); - auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); - // 2. test API - auto out = pten::Scale( - *(static_cast(dev_ctx)), - dense_x, - scale, - bias, - bias_after_scale); + pten::CPUContext dev_ctx; + auto out = + pten::Scale(dev_ctx, dense_x, scale, bias, bias_after_scale); // 3. check result ASSERT_EQ(out.dims().size(), 2); diff --git a/paddle/pten/tests/kernels/test_sum_dev_api.cc b/paddle/pten/tests/kernels/test_sum_dev_api.cc index 312a6ce6100..2b11ba9595c 100644 --- a/paddle/pten/tests/kernels/test_sum_dev_api.cc +++ b/paddle/pten/tests/kernels/test_sum_dev_api.cc @@ -42,18 +42,12 @@ TEST(DEV_API, sum) { dense_x_data[i] = i * 1.0; sum += i * 1.0; } - paddle::platform::DeviceContextPool& pool = - paddle::platform::DeviceContextPool::Instance(); - auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); std::vector axis = {0, 1}; + pten::CPUContext dev_ctx; // 2. test API - auto out = pten::Sum( - *(static_cast(dev_ctx)), - dense_x, - axis, - pten::DataType::FLOAT32, - false); + auto out = + pten::Sum(dev_ctx, dense_x, axis, pten::DataType::FLOAT32, false); // 3. check result ASSERT_EQ(out.dims().size(), 1); -- GitLab