Unverified commit 7d727f36, authored by Huang Jiyi, committed by GitHub

[phi decoupling] decouple dependency to device_context in phi (Part 3) (#51559)

* remove device_context include

* fix bug

* fix bug
Parent e79699fb
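
The pattern applied throughout this commit: phi helpers that used to accept the polymorphic phi::DeviceContext base class (pulled in via fluid's device_context.h) and reinterpret_cast it to a concrete backend context are rewritten as templates over the concrete context type, so the phi headers no longer need the fluid include. The following is a minimal standalone sketch of that technique, not Paddle code; CPUContext, GPUContext, and CopyBytes here are illustrative stand-ins.

#include <cstring>
#include <iostream>

// Stand-ins for phi::CPUContext / phi::GPUContext (illustrative only).
struct CPUContext {};
struct GPUContext {
  void* stream() const { return nullptr; }
};

// Generic version: a device build would enqueue the copy on ctx.stream().
template <typename Context>
void CopyBytes(const Context& ctx, void* dst, const void* src,
               std::size_t n) {
  (void)ctx.stream();        // touch the stream, as a device path would
  std::memcpy(dst, src, n);  // placeholder for the asynchronous device copy
}

// Full specialization for the CPU: no stream, no cast, no #ifdef needed.
template <>
void CopyBytes<CPUContext>(const CPUContext&, void* dst, const void* src,
                           std::size_t n) {
  std::memcpy(dst, src, n);
}

int main() {
  char src[4] = "phi";
  char dst[4] = {};
  CopyBytes(CPUContext{}, dst, src, sizeof src);  // picks the specialization
  std::cout << dst << '\n';                       // prints: phi
}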
......
@@ -103,7 +103,8 @@ static void SplitTensorsForAllReduce(
   }
   // Sometimes direct copies will be faster
   if (p_dense_tensors->size() < 10) {
-    phi::funcs::StridedMemcpyWithAxis0<T>(context, *in, shape_refer, &outs);
+    phi::funcs::StridedMemcpyWithAxis0<T, DeviceContext>(
+        context, *in, shape_refer, &outs);
   } else {
     operators::math::SplitFunctor<DeviceContext, T> split_functor_;
     split_functor_(context, *in, shape_refer, 0, &outs);
......
......
@@ -727,7 +727,8 @@ void _concatCompute(const std::vector<phi::DenseTensor> &ins,
   for (auto &in : ins) {
     auto in_stride = phi::stride_numel(in.dims());
     auto out_stride = phi::stride_numel(out->dims());
-    phi::funcs::StridedNumelCopyWithAxis<T>(ctx,
+    phi::funcs::StridedNumelCopyWithAxis<T, phi::CPUContext>(
+        ctx,
         axis,
         out->data<T>() + output_offset,
         out_stride,
......
......
@@ -86,7 +86,8 @@ void ConcatKernel(const Context& dev_ctx,
   }
   auto in_stride = phi::stride_numel(in->dims());
   auto out_stride = phi::stride_numel(out->dims());
-  phi::funcs::StridedNumelCopyWithAxis<T>(dev_ctx,
+  phi::funcs::StridedNumelCopyWithAxis<T, Context>(
+      dev_ctx,
       axis,
       out->data<T>() + output_offset,
       out_stride,
......
......
@@ -27,7 +27,6 @@ limitations under the License. */
 #include <utility>
 #include <vector>
-#include "paddle/fluid/platform/device_context.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/data_type.h"
......
......
@@ -14,7 +14,6 @@ limitations under the License. */
 #include <algorithm>
 #include <vector>
-#include "paddle/fluid/platform/device_context.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/data_type.h"
......
......
@@ -14,9 +14,12 @@ limitations under the License. */
 #include "paddle/phi/kernels/funcs/detail/strided_memcpy.h"
-#include "paddle/fluid/platform/device_context.h"
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+class CPUContext;
+}  // namespace phi
 
 namespace phi {
 namespace funcs {
......
@@ -46,6 +49,32 @@ inline void StridedMemcpy(const phi::DeviceContext& dev_ctx,
   dst_dim.apply_visitor(func);
 }
 
+template <typename Context>
+inline void CopyWithContext(const Context& ctx,
+                            const Place& dst_place,
+                            void* dst,
+                            const Place& src_place,
+                            const void* src,
+                            size_t num) {
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
+    defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
+  memory_utils::Copy(dst_place, dst, src_place, src, num, ctx.stream());
+#else
+  PADDLE_THROW(
+      phi::errors::PreconditionNotMet("Paddle is not compiled with GPU."));
+#endif
+}
+
+template <>
+inline void CopyWithContext<phi::CPUContext>(const phi::CPUContext& ctx,
+                                             const Place& dst_place,
+                                             void* dst,
+                                             const Place& src_place,
+                                             const void* src,
+                                             size_t num) {
+  memory_utils::Copy(dst_place, dst, src_place, src, num);
+}
 // Strided numel memory copy from src to dst by the specified axis
 //
 // For example, for a tensor dims [4, 20, 100], the strided numel is
......
@@ -53,8 +82,8 @@ inline void StridedMemcpy(const phi::DeviceContext& dev_ctx,
 //
 // NOTE: The src and dst tensor should have the same elements
 // except the specified axis.
-template <typename T>
-inline void StridedNumelCopyWithAxis(const phi::DeviceContext& ctx,
+template <typename T, typename Context>
+inline void StridedNumelCopyWithAxis(const Context& ctx,
                                      int64_t axis,
                                      T* dst,
                                      const phi::DDim& dst_stride_numel,
......
@@ -102,52 +131,18 @@ inline void StridedNumelCopyWithAxis(const phi::DeviceContext& ctx,
   }
 
   for (int64_t i = 0; i < before; ++i) {
-    if (place.GetType() == phi::AllocationType::CPU) {
-      auto& cpu_place = place;
-      memory_utils::Copy(cpu_place,
-                         dst + i * dst_after,
-                         cpu_place,
-                         src + i * src_after,
-                         sizeof(T) * size);
-    } else {
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-      auto& gpu_place = place;
-      auto& cuda_ctx = reinterpret_cast<const phi::GPUContext&>(ctx);
-      memory_utils::Copy(gpu_place,
-                         dst + i * dst_after,
-                         gpu_place,
-                         src + i * src_after,
-                         sizeof(T) * size,
-                         cuda_ctx.stream());
-#elif defined(PADDLE_WITH_ASCEND_CL)
-      auto& npu_place = place;
-      auto& npu_ctx = reinterpret_cast<const platform::NPUDeviceContext&>(ctx);
-      memory_utils::Copy(npu_place,
-                         dst + i * dst_after,
-                         npu_place,
-                         src + i * src_after,
-                         sizeof(T) * size,
-                         npu_ctx.stream());
-#elif defined(PADDLE_WITH_MLU)
-      auto& mlu_place = place;
-      auto& mlu_ctx = reinterpret_cast<const platform::MLUDeviceContext&>(ctx);
-      memory_utils::Copy(mlu_place,
-                         dst + i * dst_after,
-                         mlu_place,
-                         src + i * src_after,
-                         sizeof(T) * size,
-                         mlu_ctx.stream());
-#else
-      PADDLE_THROW(
-          phi::errors::PreconditionNotMet("Paddle is not compiled with GPU."));
-#endif
-    }
+    CopyWithContext<Context>(ctx,
+                             place,
+                             dst + i * dst_after,
+                             place,
+                             src + i * src_after,
+                             sizeof(T) * size);
   }
 }
 
-template <typename T>
+template <typename T, typename Context>
 inline void StridedMemcpyWithAxis0(
-    const phi::DeviceContext& dev_ctx,
+    const Context& dev_ctx,
     const phi::DenseTensor& input,
     const std::vector<const phi::DenseTensor*>& shape_refer,
     std::vector<phi::DenseTensor*>* outputs) {
......
@@ -159,7 +154,7 @@ inline void StridedMemcpyWithAxis0(
     auto out_stride = stride_numel(shape_refer[i]->dims());
     auto out = outputs->at(i);
     if (out != nullptr && out->initialized() && out->numel() > 0) {
-      StridedNumelCopyWithAxis<T>(dev_ctx,
+      StridedNumelCopyWithAxis<T, Context>(dev_ctx,
                                   axis,
                                   out->data<T>(),
                                   out_stride,
......
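
To make the stride-numel comment above concrete: for dims [4, 20, 100] the stride numel is [8000, 2000, 100], the number of elements covered from each axis to the end (4*20*100, 20*100, 100). StridedNumelCopyWithAxis then copies `before` chunks of `size` elements, stepping by the source and destination strides, which is exactly what concat and split need. Below is a standalone sketch of the flattened axis-1 case; StridedCopyAxis1 is a hypothetical name, not Paddle's API.

#include <cstdio>
#include <vector>

// Copy src (shape [before, src_after]) into dst (shape [before, dst_after])
// starting at column `offset` -- the flattened form of a copy along axis 1.
void StridedCopyAxis1(float* dst, int dst_after,
                      const float* src, int src_after,
                      int before, int offset) {
  for (int i = 0; i < before; ++i)
    for (int j = 0; j < src_after; ++j)
      dst[i * dst_after + offset + j] = src[i * src_after + j];
}

int main() {
  // Concat two [2, 2] tensors along axis 1 into a [2, 4] output.
  std::vector<float> a = {1, 2, 3, 4}, b = {5, 6, 7, 8}, out(8, 0);
  StridedCopyAxis1(out.data(), 4, a.data(), 2, /*before=*/2, /*offset=*/0);
  StridedCopyAxis1(out.data(), 4, b.data(), 2, /*before=*/2, /*offset=*/2);
  for (float v : out) std::printf("%g ", v);  // 1 2 5 6 3 4 7 8
  std::printf("\n");
}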
......
@@ -85,7 +85,8 @@ void ConcatKernel(const Context& dev_ctx,
   }
   auto in_stride = phi::stride_numel(in->dims());
   auto out_stride = phi::stride_numel(out->dims());
-  phi::funcs::StridedNumelCopyWithAxis<T>(dev_ctx,
+  phi::funcs::StridedNumelCopyWithAxis<T, Context>(
+      dev_ctx,
       axis,
       out->data<T>() + output_offset,
       out_stride,
......
......
@@ -57,7 +57,7 @@ void ConcatGradKernel(const Context& dev_ctx,
   if (axis == 0 && outs.size() < 10) {
     std::vector<const DenseTensor*> ref_shape;
     ref_shape.insert(ref_shape.begin(), x.begin(), x.end());
-    phi::funcs::StridedMemcpyWithAxis0<T>(
+    phi::funcs::StridedMemcpyWithAxis0<T, Context>(
         dev_ctx, out_grad, ref_shape, &outputs);
   } else {
     phi::funcs::SplitFunctor<Context, T> split_functor;
......
......
@@ -37,7 +37,8 @@ void SplitKernel(const Context& dev_ctx,
   int axis = axis_scalar.to<int>();
   // Sometimes direct copies will be faster; this may need deeper analysis.
   if (axis == 0 && outs.size() < 10) {
-    phi::funcs::StridedMemcpyWithAxis0<T>(dev_ctx, x, shape_refer, &outs);
+    phi::funcs::StridedMemcpyWithAxis0<T, Context>(
+        dev_ctx, x, shape_refer, &outs);
   } else {
     phi::funcs::SplitFunctor<Context, T> functor;
     functor(dev_ctx, x, shape_refer, axis, &outs);
......
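
All the kernel-side hunks follow the same shape: each kernel is already templated on Context, so it forwards that type as the new second template argument of the copy helper (only T is strictly non-deducible, but spelling out both keeps the call sites explicit). A hypothetical mini-kernel, not Paddle's API, showing the forwarding:

#include <cstring>
#include <iostream>
#include <vector>

struct CPUContext {};  // stand-in for a phi context type

// Hypothetical copy helper in the style of StridedNumelCopyWithAxis:
// templated on both the element type and the concrete context.
template <typename T, typename Context>
void CopyChunk(const Context& /*ctx*/, T* dst, const T* src, std::size_t n) {
  std::memcpy(dst, src, n * sizeof(T));  // CPU path; device paths would differ
}

// A mini concat along axis 0: the kernel is itself templated on Context and
// simply forwards that type to the helper, as the updated call sites do.
template <typename T, typename Context>
void ConcatAxis0(const Context& ctx, const std::vector<std::vector<T>>& ins,
                 std::vector<T>* out) {
  std::size_t offset = 0;
  for (const auto& in : ins) {
    CopyChunk<T, Context>(ctx, out->data() + offset, in.data(), in.size());
    offset += in.size();
  }
}

int main() {
  std::vector<std::vector<float>> ins = {{1, 2}, {3, 4, 5}};
  std::vector<float> out(5);
  ConcatAxis0<float>(CPUContext{}, ins, &out);
  for (float v : out) std::cout << v << ' ';  // 1 2 3 4 5
  std::cout << '\n';
}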