未验证 提交 7d727f36 编写于 作者: H Huang Jiyi 提交者: GitHub

[phi decopuling] decouple dependency to device_context in phi (Part 3) (#51559)

* remove device_context include

* fix bug

* fix bug
上级 e79699fb
...@@ -103,7 +103,8 @@ static void SplitTensorsForAllReduce( ...@@ -103,7 +103,8 @@ static void SplitTensorsForAllReduce(
} }
// Sometimes direct copies will be faster // Sometimes direct copies will be faster
if (p_dense_tensors->size() < 10) { if (p_dense_tensors->size() < 10) {
phi::funcs::StridedMemcpyWithAxis0<T>(context, *in, shape_refer, &outs); phi::funcs::StridedMemcpyWithAxis0<T, DeviceContext>(
context, *in, shape_refer, &outs);
} else { } else {
operators::math::SplitFunctor<DeviceContext, T> split_functor_; operators::math::SplitFunctor<DeviceContext, T> split_functor_;
split_functor_(context, *in, shape_refer, 0, &outs); split_functor_(context, *in, shape_refer, 0, &outs);
......
...@@ -727,13 +727,14 @@ void _concatCompute(const std::vector<phi::DenseTensor> &ins, ...@@ -727,13 +727,14 @@ void _concatCompute(const std::vector<phi::DenseTensor> &ins,
for (auto &in : ins) { for (auto &in : ins) {
auto in_stride = phi::stride_numel(in.dims()); auto in_stride = phi::stride_numel(in.dims());
auto out_stride = phi::stride_numel(out->dims()); auto out_stride = phi::stride_numel(out->dims());
phi::funcs::StridedNumelCopyWithAxis<T>(ctx, phi::funcs::StridedNumelCopyWithAxis<T, phi::CPUContext>(
axis, ctx,
out->data<T>() + output_offset, axis,
out_stride, out->data<T>() + output_offset,
in.data<T>(), out_stride,
in_stride, in.data<T>(),
in_stride[axis]); in_stride,
in_stride[axis]);
output_offset += in_stride[axis]; output_offset += in_stride[axis];
} }
} else { } else {
......
...@@ -86,13 +86,14 @@ void ConcatKernel(const Context& dev_ctx, ...@@ -86,13 +86,14 @@ void ConcatKernel(const Context& dev_ctx,
} }
auto in_stride = phi::stride_numel(in->dims()); auto in_stride = phi::stride_numel(in->dims());
auto out_stride = phi::stride_numel(out->dims()); auto out_stride = phi::stride_numel(out->dims());
phi::funcs::StridedNumelCopyWithAxis<T>(dev_ctx, phi::funcs::StridedNumelCopyWithAxis<T, Context>(
axis, dev_ctx,
out->data<T>() + output_offset, axis,
out_stride, out->data<T>() + output_offset,
in->data<T>(), out_stride,
in_stride, in->data<T>(),
in_stride[axis]); in_stride,
in_stride[axis]);
output_offset += in_stride[axis]; output_offset += in_stride[axis];
} }
} else { } else {
......
...@@ -27,7 +27,6 @@ limitations under the License. */ ...@@ -27,7 +27,6 @@ limitations under the License. */
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
......
...@@ -14,7 +14,6 @@ limitations under the License. */ ...@@ -14,7 +14,6 @@ limitations under the License. */
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
......
...@@ -14,9 +14,12 @@ limitations under the License. */ ...@@ -14,9 +14,12 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/detail/strided_memcpy.h" #include "paddle/phi/kernels/funcs/detail/strided_memcpy.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
namespace phi {
class CPUContext;
} // namespace phi
namespace phi { namespace phi {
namespace funcs { namespace funcs {
...@@ -46,6 +49,32 @@ inline void StridedMemcpy(const phi::DeviceContext& dev_ctx, ...@@ -46,6 +49,32 @@ inline void StridedMemcpy(const phi::DeviceContext& dev_ctx,
dst_dim.apply_visitor(func); dst_dim.apply_visitor(func);
} }
template <typename Context>
inline void CopyWithContext(const Context& ctx,
const Place& dst_place,
void* dst,
const Place& src_place,
const void* src,
size_t num) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
memory_utils::Copy(dst_place, dst, src_place, src, num, ctx.stream());
#else
PADDLE_THROW(
phi::errors::PreconditionNotMet("Paddle is not compiled with GPU."));
#endif
}
template <>
inline void CopyWithContext<phi::CPUContext>(const phi::CPUContext& ctx,
const Place& dst_place,
void* dst,
const Place& src_place,
const void* src,
size_t num) {
memory_utils::Copy(dst_place, dst, src_place, src, num);
}
// Strided numel memory copy from src to dst by the specified axis // Strided numel memory copy from src to dst by the specified axis
// //
// For example, for a tensor dims [4, 20, 100], the strieded numel is // For example, for a tensor dims [4, 20, 100], the strieded numel is
...@@ -53,8 +82,8 @@ inline void StridedMemcpy(const phi::DeviceContext& dev_ctx, ...@@ -53,8 +82,8 @@ inline void StridedMemcpy(const phi::DeviceContext& dev_ctx,
// //
// NOTE: The src and dst tensor should have the same elements // NOTE: The src and dst tensor should have the same elements
// except the specified axis. // except the specified axis.
template <typename T> template <typename T, typename Context>
inline void StridedNumelCopyWithAxis(const phi::DeviceContext& ctx, inline void StridedNumelCopyWithAxis(const Context& ctx,
int64_t axis, int64_t axis,
T* dst, T* dst,
const phi::DDim& dst_stride_numel, const phi::DDim& dst_stride_numel,
...@@ -102,52 +131,18 @@ inline void StridedNumelCopyWithAxis(const phi::DeviceContext& ctx, ...@@ -102,52 +131,18 @@ inline void StridedNumelCopyWithAxis(const phi::DeviceContext& ctx,
} }
for (int64_t i = 0; i < before; ++i) { for (int64_t i = 0; i < before; ++i) {
if (place.GetType() == phi::AllocationType::CPU) { CopyWithContext<Context>(ctx,
auto& cpu_place = place; place,
memory_utils::Copy(cpu_place, dst + i * dst_after,
dst + i * dst_after, place,
cpu_place, src + i * src_after,
src + i * src_after, sizeof(T) * size);
sizeof(T) * size);
} else {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto& gpu_place = place;
auto& cuda_ctx = reinterpret_cast<const phi::GPUContext&>(ctx);
memory_utils::Copy(gpu_place,
dst + i * dst_after,
gpu_place,
src + i * src_after,
sizeof(T) * size,
cuda_ctx.stream());
#elif defined(PADDLE_WITH_ASCEND_CL)
auto& npu_place = place;
auto& npu_ctx = reinterpret_cast<const platform::NPUDeviceContext&>(ctx);
memory_utils::Copy(npu_place,
dst + i * dst_after,
npu_place,
src + i * src_after,
sizeof(T) * size,
npu_ctx.stream());
#elif defined(PADDLE_WITH_MLU)
auto& mlu_place = place;
auto& mlu_ctx = reinterpret_cast<const platform::MLUDeviceContext&>(ctx);
memory_utils::Copy(mlu_place,
dst + i * dst_after,
mlu_place,
src + i * src_after,
sizeof(T) * size,
mlu_ctx.stream());
#else
PADDLE_THROW(
phi::errors::PreconditionNotMet("Paddle is not compiled with GPU."));
#endif
}
} }
} }
template <typename T> template <typename T, typename Context>
inline void StridedMemcpyWithAxis0( inline void StridedMemcpyWithAxis0(
const phi::DeviceContext& dev_ctx, const Context& dev_ctx,
const phi::DenseTensor& input, const phi::DenseTensor& input,
const std::vector<const phi::DenseTensor*>& shape_refer, const std::vector<const phi::DenseTensor*>& shape_refer,
std::vector<phi::DenseTensor*>* outputs) { std::vector<phi::DenseTensor*>* outputs) {
...@@ -159,13 +154,13 @@ inline void StridedMemcpyWithAxis0( ...@@ -159,13 +154,13 @@ inline void StridedMemcpyWithAxis0(
auto out_stride = stride_numel(shape_refer[i]->dims()); auto out_stride = stride_numel(shape_refer[i]->dims());
auto out = outputs->at(i); auto out = outputs->at(i);
if (out != nullptr && out->initialized() && out->numel() > 0) { if (out != nullptr && out->initialized() && out->numel() > 0) {
StridedNumelCopyWithAxis<T>(dev_ctx, StridedNumelCopyWithAxis<T, Context>(dev_ctx,
axis, axis,
out->data<T>(), out->data<T>(),
out_stride, out_stride,
input.data<T>() + input_offset, input.data<T>() + input_offset,
in_stride, in_stride,
out_stride[axis]); out_stride[axis]);
} }
input_offset += out_stride[axis]; input_offset += out_stride[axis];
} }
......
...@@ -85,13 +85,14 @@ void ConcatKernel(const Context& dev_ctx, ...@@ -85,13 +85,14 @@ void ConcatKernel(const Context& dev_ctx,
} }
auto in_stride = phi::stride_numel(in->dims()); auto in_stride = phi::stride_numel(in->dims());
auto out_stride = phi::stride_numel(out->dims()); auto out_stride = phi::stride_numel(out->dims());
phi::funcs::StridedNumelCopyWithAxis<T>(dev_ctx, phi::funcs::StridedNumelCopyWithAxis<T, Context>(
axis, dev_ctx,
out->data<T>() + output_offset, axis,
out_stride, out->data<T>() + output_offset,
in->data<T>(), out_stride,
in_stride, in->data<T>(),
in_stride[axis]); in_stride,
in_stride[axis]);
output_offset += in_stride[axis]; output_offset += in_stride[axis];
} }
} else { } else {
......
...@@ -57,7 +57,7 @@ void ConcatGradKernel(const Context& dev_ctx, ...@@ -57,7 +57,7 @@ void ConcatGradKernel(const Context& dev_ctx,
if (axis == 0 && outs.size() < 10) { if (axis == 0 && outs.size() < 10) {
std::vector<const DenseTensor*> ref_shape; std::vector<const DenseTensor*> ref_shape;
ref_shape.insert(ref_shape.begin(), x.begin(), x.end()); ref_shape.insert(ref_shape.begin(), x.begin(), x.end());
phi::funcs::StridedMemcpyWithAxis0<T>( phi::funcs::StridedMemcpyWithAxis0<T, Context>(
dev_ctx, out_grad, ref_shape, &outputs); dev_ctx, out_grad, ref_shape, &outputs);
} else { } else {
phi::funcs::SplitFunctor<Context, T> split_functor; phi::funcs::SplitFunctor<Context, T> split_functor;
......
...@@ -37,7 +37,8 @@ void SplitKernel(const Context& dev_ctx, ...@@ -37,7 +37,8 @@ void SplitKernel(const Context& dev_ctx,
int axis = axis_scalar.to<int>(); int axis = axis_scalar.to<int>();
// Sometimes direct copies will be faster, this maybe need deeply analysis. // Sometimes direct copies will be faster, this maybe need deeply analysis.
if (axis == 0 && outs.size() < 10) { if (axis == 0 && outs.size() < 10) {
phi::funcs::StridedMemcpyWithAxis0<T>(dev_ctx, x, shape_refer, &outs); phi::funcs::StridedMemcpyWithAxis0<T, Context>(
dev_ctx, x, shape_refer, &outs);
} else { } else {
phi::funcs::SplitFunctor<Context, T> functor; phi::funcs::SplitFunctor<Context, T> functor;
functor(dev_ctx, x, shape_refer, axis, &outs); functor(dev_ctx, x, shape_refer, axis, &outs);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册