...
 
Commits (13)
    https://gitcode.net/megvii/megengine/-/commit/813628e2a6b093422b6d111f7e87e2a8c2b583aa feat(opr): add interpolate trilinear 2023-07-04T22:04:15+08:00 Megvii Engine Team megengine@megvii.com GitOrigin-RevId: 19a96ba58bdf645ceb15655dc2f9d29085c677a7
    https://gitcode.net/megvii/megengine/-/commit/b7d9cfa0f2413166f2924a3d32c4681960acd047 ci(tablegen): check src dir clean after build 2023-07-05T11:23:46+08:00 Megvii Engine Team megengine@megvii.com GitOrigin-RevId: d49a919a76093ed11b451a205e525961c5004503
    https://gitcode.net/megvii/megengine/-/commit/d5477fdcff294bfe98fb421bd14e7fff71540933 feat(imperative): add xla python code 2023-07-05T23:04:02+08:00 Megvii Engine Team megengine@megvii.com GitOrigin-RevId: eeda7aadce6380110fecd00ecd3e99b384c82764
    https://gitcode.net/megvii/megengine/-/commit/25434b5279012bd7bd950a46cd4e3cb7a08ce52d test(imperative/test): add xla op lower test 2023-07-05T23:04:09+08:00 Megvii Engine Team megengine@megvii.com GitOrigin-RevId: 1416ab7add4e9cd964aedc9676b0ccf3ab789837
    https://gitcode.net/megvii/megengine/-/commit/c82782d046cdba450070ef132b0942a46f8efb82 refactor(imperative): change to xlatrace 2023-07-05T23:04:17+08:00 Megvii Engine Team megengine@megvii.com GitOrigin-RevId: 9403f0ce5ca5623007d5058228477418e3782232
    https://gitcode.net/megvii/megengine/-/commit/9914129a07315311b522208db94b54577dd7e739 fix(imperative): fix the circular import in trace 2023-07-05T23:04:24+08:00 Megvii Engine Team megengine@megvii.com GitOrigin-RevId: 4c0ed6fc55e9af9fd245997c6277363e2b6c2cab
    https://gitcode.net/megvii/megengine/-/commit/4ae9dd0074df6374c4b9e77217b37546d34b486c feat(imperative): add external transform 2023-07-06T14:04:23+08:00 Megvii Engine Team megengine@megvii.com GitOrigin-RevId: e8e3ebe9c86afc9fb97900b5b5af9778cc1354e5
    https://gitcode.net/megvii/megengine/-/commit/d8917c22ef4dcdc4f23db8fd6c155e5a5dfb0656 refactor(xla_trace): convert params in compile 2023-07-06T14:04:30+08:00 Megvii Engine Team megengine@megvii.com GitOrigin-RevId: c00e0592810d717aa2f786e64165ce78041a83ed
    https://gitcode.net/megvii/megengine/-/commit/9d535d7ac34cbb9feba71e418d7a84198e65306a feat(xla): improve lower rule 2023-07-06T14:04:38+08:00 Megvii Engine Team megengine@megvii.com GitOrigin-RevId: 55d43fe0f3666ef233612505a925662161b50bec
    https://gitcode.net/megvii/megengine/-/commit/0d2b4db9f0891d8bebbfee9efcb553f882b24dae ci(benchmark,convergence): fix some benchmark and convergence-test problems 2023-07-07T12:04:05+08:00 Megvii Engine Team megengine@megvii.com GitOrigin-RevId: 330efd06c3c7983885f81fd0df41bf8dcdc2bce3
    https://gitcode.net/megvii/megengine/-/commit/4c7905f3d40535cefb6f948b6809af965dc647fc feat(imperative): add some xla op rules 2023-07-07T16:07:43+08:00 Megvii Engine Team megengine@megvii.com GitOrigin-RevId: 0650c75dc1e4ec9af8ae7d9ed3eca60e4681e04a
    https://gitcode.net/megvii/megengine/-/commit/5e013d8c57690311e532b18d154055c4d8bb7051 refactor(xla): add xla acknowledgement 2023-07-07T16:07:50+08:00 Megvii Engine Team megengine@megvii.com GitOrigin-RevId: f2fafebedfc3b520684655cff8c86b1df7c2c1cd
    https://gitcode.net/megvii/megengine/-/commit/281ecd0b580506038ce49b888128d0676013422d feat(xla): support IndexingMultiAxisVec and IndexingIncrMultiAxisVec 2023-07-07T21:52:28+08:00 Megvii Engine Team megengine@megvii.com GitOrigin-RevId: ca13d142ef5f6d952350f7217f1aebc2ff644dd6
......@@ -755,6 +755,11 @@ Copyright 2014 Google Inc. All rights reserved.
5. MACE
Copyright 2018 Xiaomi Inc. All rights reserved.
6. XLA
Copyright 2017 The TensorFlow Authors. All Rights Reserved.
7. JAX
Copyright 2018 The JAX Authors.
Terms of Apache License Version 2.0
---------------------------------------------------
......
......@@ -7,12 +7,8 @@ endif()
if("${CUDA_ROOT_DIR}" STREQUAL "" AND NOT "$ENV{CUDA_BIN_PATH}" STREQUAL "")
set(CUDA_ROOT_DIR $ENV{CUDA_BIN_PATH})
endif()
if("${CUDA_ROOT_DIR}" STREQUAL "")
message(
FATAL_ERROR
"Can not find CUDA, please export cuda sdk path to CUDA_ROOT_DIR or CUDA_PATH or CUDA_BIN_PATH"
)
endif()
# The ${CUDA_ROOT_DIR} check was removed here because users may not always set the env variable
# TODO: find_library(CUDA_ROOT_DIR) in cmake/cuda.cmake
set(MGE_CUPTI_USE_STATIC ${MGE_CUDA_USE_STATIC})
......
......@@ -245,6 +245,35 @@ protected:
size_t workspace_in_bytes);
};
class Resize3DBase : public OperatorBase {
DEF_OPR_PARAM(Resize3D);
DEF_OPR_IMPL(Resize3DBase, OperatorBase, 1, 1);
public:
using InterpolationMode = Param::InterpolationMode;
protected:
void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
};
class Resize3DForward : public Resize3DBase {
DEF_OPR_IMPL(Resize3DForward, Resize3DBase, 1, 1);
public:
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) = 0;
protected:
void check_exec(
const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
};
using Resize3D = Resize3DForward;
/**
* \brief Remap opr.
*/
......
......@@ -965,6 +965,11 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
'ksize_d', 0, 'ksize_h', 3, 'ksize_w', 3,
'anchor_d', 0, 'anchor_h', 1, 'anchor_w', 1))
(pdef('Resize3D')
.add_enum_alias('InterpolationMode', 'WarpPerspectiveV1', name_field='imode')
.add_enum_alias('Format', 'Convolution3D', default=1)
.add_fields('bool', 'align_corners', 'false'))
(pdef('TopK').
add_enum(
'Mode',
......
......@@ -160,6 +160,7 @@ private:
cb(GaussianBlur) \
cb(Resize) \
cb(ResizeBackward) \
cb(Resize3D) \
cb(ParamPackConcat) \
cb(MaxTensorDiff) \
cb(MaskConvForward) \
......
......@@ -150,6 +150,7 @@ DEF(GroupNormBackward, 8, true, true);
DEF(MaskedFill, 3, false, true);
DEF(MultiHeadAttnForward, 11, true, true);
DEF(MultiHeadAttnBackward, 15, true, true);
DEF(Resize3D, 2, true, false);
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -111,6 +111,39 @@ std::tuple<float, int, float, int> ResizeBase::get_nearest_linear_coord(
int ResizeBase::get_nearest_src(float scale, int size, int idx) {
return std::min(static_cast<int>(idx / scale), size - 1);
}
void Resize3DBase::check_layout_fwd(const TensorLayout& src, const TensorLayout& dst) {
auto errmsg = [&]() {
return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(dst);
};
MEGDNN_MARK_USED_VAR(errmsg);
megdnn_assert(
param().format == Param::Format::NCDHW, "Resize3D only supports NCDHW");
megdnn_assert(
src.ndim == 5 && dst.ndim == 5, "shape dim mismatch: %s", errmsg().c_str());
megdnn_assert(src.dtype == dst.dtype, "dtype mismatch: %s", errmsg().c_str());
megdnn_assert(
src.shape[0] == dst.shape[0], "batch size mismatch: %s", errmsg().c_str());
megdnn_assert(
src.shape[1] == dst.shape[1], "channel size mismatch: %s",
errmsg().c_str());
megdnn_assert_contiguous(src);
megdnn_assert_contiguous(dst);
auto imode = param().imode;
using IMode = param::Resize3D::InterpolationMode;
megdnn_assert(imode == IMode::INTER_LINEAR, "Resize3D only supports trilinear mode");
}
void Resize3D::check_exec(
const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) {
check_layout_fwd(src, dst);
auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -177,6 +177,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(WarpAffine);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(GaussianBlur);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Resize);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ResizeBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Resize3D);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ParamPackConcat);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MaxTensorDiff);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MaskConvForward);
......
......@@ -26,6 +26,15 @@ void backward_data_proxy(
int C, int IH, int IW, int OH, int OW, cudaStream_t stream);
} // namespace resize
namespace resize3d {
template <typename ctype>
void resize3d_forward(
const bool align_corners, const ctype* iptr, ctype* optr, const int N,
const int C, const int ID, const int IH, const int IW, const int OD,
const int OH, const int OW, cudaStream_t stream);
} // namespace resize3d
} // namespace cuda
} // namespace megdnn
......
......@@ -168,4 +168,41 @@ void ResizeImpl::exec(
}
}
size_t Resize3DImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) {
return 0;
}
void Resize3DImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size);
size_t out_depth = dst.layout.shape[2];
size_t out_height = dst.layout.shape[3];
size_t out_width = dst.layout.shape[4];
size_t in_depth = src.layout.shape[2];
size_t in_height = src.layout.shape[3];
size_t in_width = src.layout.shape[4];
bool align_corners = param().align_corners;
auto stream = cuda_stream(this->handle());
if (src.layout.dtype == dtype::Float32{}) {
resize3d::resize3d_forward(
align_corners, src.ptr<dt_float32>(), dst.ptr<dt_float32>(),
src.layout[0], src.layout[1], in_depth, in_height, in_width, out_depth,
out_height, out_width, stream);
#if !MEGDNN_DISABLE_FLOAT16
} else if (src.layout.dtype == dtype::Float16{}) {
resize3d::resize3d_forward(
align_corners, src.ptr<dt_float16>(), dst.ptr<dt_float16>(),
src.layout[0], src.layout[1], in_depth, in_height, in_width, out_depth,
out_height, out_width, stream);
#endif
} else {
megdnn_throw(ssprintf(
"unsupported dtype: %s for Resize3D", src.layout.dtype.name()));
}
}
// vim: syntax=cpp.doxygen
......@@ -308,6 +308,156 @@ DNN_INC_FLOAT16(INST(dt_float16))
INST(int8_t);
#undef INST
} // namespace resize
namespace resize3d {
__device__ __forceinline__ static float pixel_get_src_index(
float scale, int64_t dst_index, bool align_corners) {
if (align_corners) {
return scale * dst_index;
} else {
float src_idx = scale * (dst_index + 0.5f) - 0.5f;
return src_idx < 0.f ? 0.f : src_idx;
}
}
__device__ __forceinline__ static size_t index_getter(
int n, int c, int d, int h, int w, const int N, const int C, const int D,
const int H, const int W) {
return n * C * D * H * W + c * D * H * W + d * H * W + h * W + w;
}
template <typename ctype>
__global__ void trilinear_forward(
const int num_kernels, const float rdepth, const float rheight,
const float rwidth, const bool align_corners, const ctype* iptr, ctype* optr,
const int N, const int C, const int ID, const int IH, const int IW,
const int OD, const int OH, const int OW) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < num_kernels) {
const int w2 = (index % (OH * OW)) % OW;
const int h2 = (index % (OH * OW)) / OW;
const int t2 = index / (OH * OW);
if (ID == OD && IH == OH && IW == OW) {
const int t1 = t2;
const int h1 = h2;
const int w1 = w2;
for (int n = 0; n < N; n++) {
for (int c = 0; c < C; ++c) {
const ctype val =
iptr[index_getter(n, c, t1, h1, w1, N, C, ID, IH, IW)];
optr[index_getter(n, c, t2, h2, w2, N, C, OD, OH, OW)] = val;
}
}
return;
}
const float t1r = pixel_get_src_index(rdepth, t2, align_corners);
const int t1 = t1r;
const int t1p = (t1 < ID - 1) ? 1 : 0;
const float t1lambda = t1r - t1;
const float t0lambda = static_cast<float>(1) - t1lambda;
const float h1r = pixel_get_src_index(rheight, h2, align_corners);
const int h1 = h1r;
const int h1p = (h1 < IH - 1) ? 1 : 0;
const float h1lambda = h1r - h1;
const float h0lambda = static_cast<float>(1) - h1lambda;
const float w1r = pixel_get_src_index(rwidth, w2, align_corners);
const int w1 = w1r;
const int w1p = (w1 < IW - 1) ? 1 : 0;
const float w1lambda = w1r - w1;
const float w0lambda = static_cast<float>(1) - w1lambda;
for (int n = 0; n < N; n++) {
for (int c = 0; c < C; ++c) {
const float val =
t0lambda *
(h0lambda * (w0lambda * iptr[index_getter(
n, c, t1, h1, w1, N, C,
ID, IH, IW)] +
w1lambda * iptr[index_getter(
n, c, t1, h1, w1 + w1p,
N, C, ID, IH, IW)]) +
h1lambda *
(w0lambda * iptr[index_getter(
n, c, t1, h1 + h1p, w1, N,
C, ID, IH, IW)] +
w1lambda *
iptr[index_getter(
n, c, t1, h1 + h1p, w1 + w1p,
N, C, ID, IH, IW)])) +
t1lambda *
(h0lambda *
(w0lambda * iptr[index_getter(
n, c, t1 + t1p, h1, w1, N,
C, ID, IH, IW)] +
w1lambda *
iptr[index_getter(
n, c, t1 + t1p, h1, w1 + w1p,
N, C, ID, IH, IW)]) +
h1lambda *
(w0lambda * iptr[index_getter(
n, c, t1 + t1p, h1 + h1p,
w1, N, C, ID, IH, IW)] +
w1lambda * iptr[index_getter(
n, c, t1 + t1p, h1 + h1p,
w1 + w1p, N, C, ID, IH,
IW)]));
optr[index_getter(n, c, t2, h2, w2, N, C, OD, OH, OW)] =
static_cast<ctype>(val);
}
}
}
}
__host__ __forceinline__ static float get_scale(
int input_size, int output_size, bool align_corners) {
if (align_corners) {
if (output_size > 1) {
return static_cast<float>(input_size - 1) / (output_size - 1);
} else {
return 0.f;
}
} else {
return static_cast<float>(input_size) / output_size;
}
}
template <typename ctype>
void resize3d_forward(
const bool align_corners, const ctype* iptr, ctype* optr, const int N,
const int C, const int ID, const int IH, const int IW, const int OD,
const int OH, const int OW, cudaStream_t stream) {
const size_t num_kernels = OD * OH * OW;
const size_t num_threads = 512;
float rdepth = get_scale(ID, OD, align_corners);
float rheight = get_scale(IH, OH, align_corners);
float rwidth = get_scale(IW, OW, align_corners);
trilinear_forward<ctype>
<<<(num_kernels + num_threads - 1) / num_threads, num_threads, 0, stream>>>(
num_kernels, rdepth, rheight, rwidth, align_corners, iptr, optr, N,
C, ID, IH, IW, OD, OH, OW);
}
#define INST(ctype) \
template void resize3d_forward( \
const bool, const ctype*, ctype*, const int, const int, const int, \
const int, const int, const int, const int, const int, cudaStream_t);
INST(float)
DNN_INC_FLOAT16(INST(dt_float16))
#undef INST
} // namespace resize3d
} // namespace cuda
} // namespace megdnn
......
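The pair get_scale / pixel_get_src_index above implements the usual align_corners coordinate mapping: with align_corners the source index is scale * dst with scale = (in - 1) / (out - 1), otherwise it is scale * (dst + 0.5) - 0.5 clamped at zero with scale = in / out. A minimal Python sketch of just that mapping (illustrative only, mirroring the kernel's rules, not code from this patch):

def get_scale(in_size, out_size, align_corners):
    # same rule as get_scale() in the kernel above
    if align_corners:
        return (in_size - 1) / (out_size - 1) if out_size > 1 else 0.0
    return in_size / out_size

def src_index(scale, dst_index, align_corners):
    # same rule as pixel_get_src_index(): only the lower bound is clamped;
    # the upper border is handled by the (x < size - 1) neighbour test in the kernel
    if align_corners:
        return scale * dst_index
    return max(scale * (dst_index + 0.5) - 0.5, 0.0)

# upsampling a length-2 axis to length 4
for ac in (True, False):
    s = get_scale(2, 4, ac)
    print(ac, [round(src_index(s, d, ac), 4) for d in range(4)])
# True  -> [0.0, 0.3333, 0.6667, 1.0]
# False -> [0.0, 0.25, 0.75, 1.25]  (1.25 still reads input[1] twice at the border)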
......@@ -24,6 +24,16 @@ public:
const TensorLayout& diff, const TensorLayout& grad) override;
};
class Resize3DImpl final : public Resize3D {
public:
using Resize3D::Resize3D;
void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) override;
};
} // namespace cuda
} // namespace megdnn
......
......@@ -545,4 +545,149 @@ void ResizeBackwardImpl::exec(
}
}
template <typename ctype>
void Resize3DImpl::kern_naive(
const float rdepth, const float rheight, const float rwidth,
const bool align_corners, const ctype* iptr, ctype* optr, const int N,
const int C, const int ID, const int IH, const int IW, const int OD,
const int OH, const int OW) {
auto pixel_get_src_index = [](float scale, int64_t dst_index, bool align_corners) {
if (align_corners) {
return scale * dst_index;
} else {
float src_idx = scale * (dst_index + 0.5f) - 0.5f;
return src_idx < 0.f ? 0.f : src_idx;
}
};
auto i_index = [&](int in, int ic, int id, int ih, int iw) -> int {
return in * C * ID * IH * IW + ic * ID * IH * IW + id * IH * IW + ih * IW + iw;
};
auto o_index = [&](int in, int ic, int id, int ih, int iw) -> int {
return in * C * OD * OH * OW + ic * OD * OH * OW + id * OH * OW + ih * OW + iw;
};
for (int t2 = 0; t2 < OD; ++t2) {
for (int h2 = 0; h2 < OH; ++h2) {
for (int w2 = 0; w2 < OW; ++w2) {
const float t1r = pixel_get_src_index(rdepth, t2, align_corners);
const int t1 = t1r;
const int t1p = (t1 < ID - 1) ? 1 : 0;
const float t1lambda = t1r - t1;
const float t0lambda = static_cast<float>(1) - t1lambda;
const float h1r = pixel_get_src_index(rheight, h2, align_corners);
const int h1 = h1r;
const int h1p = (h1 < IH - 1) ? 1 : 0;
const float h1lambda = h1r - h1;
const float h0lambda = static_cast<float>(1) - h1lambda;
const float w1r = pixel_get_src_index(rwidth, w2, align_corners);
const int w1 = w1r;
const int w1p = (w1 < IW - 1) ? 1 : 0;
const float w1lambda = w1r - w1;
const float w0lambda = static_cast<float>(1) - w1lambda;
for (int n = 0; n < N; ++n) {
for (int c = 0; c < C; ++c) {
const float val =
t0lambda *
(h0lambda *
(w0lambda *
iptr[i_index(
n, c, t1, h1, w1)] +
w1lambda * iptr[i_index(
n, c, t1, h1,
w1 + w1p)]) +
h1lambda *
(w0lambda * iptr[i_index(
n, c, t1, h1 + h1p,
w1)] +
w1lambda * iptr[i_index(
n, c, t1, h1 + h1p,
w1 + w1p)])) +
t1lambda *
(h0lambda *
(w0lambda * iptr[i_index(
n, c, t1 + t1p, h1,
w1)] +
w1lambda * iptr[i_index(
n, c, t1 + t1p, h1,
w1 + w1p)]) +
h1lambda * (w0lambda * iptr[i_index(
n, c, t1 + t1p,
h1 + h1p, w1)] +
w1lambda * iptr[i_index(
n, c, t1 + t1p,
h1 + h1p,
w1 + w1p)]));
optr[o_index(n, c, t2, h2, w2)] = static_cast<ctype>(val);
}
}
}
}
}
}
#define INST(ctype) \
template void Resize3DImpl::kern_naive( \
const float, const float, const float, const bool, const ctype*, ctype*, \
const int, const int, const int, const int, const int, const int, \
const int, const int)
INST(dt_float32);
DNN_INC_FLOAT16(INST(dt_float16));
#undef INST
void Resize3DImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size);
bool align_corners = param().align_corners;
size_t N = src.layout.shape[0];
size_t C = src.layout.shape[1];
size_t OD = dst.layout.shape[2];
size_t OH = dst.layout.shape[3];
size_t OW = dst.layout.shape[4];
size_t ID = src.layout.shape[2];
size_t IH = src.layout.shape[3];
size_t IW = src.layout.shape[4];
auto get_scale = [](int input_size, int output_size, bool align_corners) -> float {
if (align_corners) {
if (output_size > 1) {
return static_cast<float>(input_size - 1) / (output_size - 1);
} else {
return 0.f;
}
} else {
return static_cast<float>(input_size) / output_size;
}
};
float rdepth = get_scale(ID, OD, align_corners);
float rheight = get_scale(IH, OH, align_corners);
float rwidth = get_scale(IW, OW, align_corners);
if (src.layout.dtype == dtype::Float32{}) {
Resize3DImpl::kern_naive(
rdepth, rheight, rwidth, align_corners, src.ptr<dt_float32>(),
dst.ptr<dt_float32>(), N, C, ID, IH, IW, OD, OH, OW);
#if !MEGDNN_DISABLE_FLOAT16
} else if (src.layout.dtype == dtype::Float16{}) {
Resize3DImpl::kern_naive(
rdepth, rheight, rwidth, align_corners, src.ptr<dt_float16>(),
dst.ptr<dt_float16>(), N, C, ID, IH, IW, OD, OH, OW);
#endif
} else {
megdnn_throw(ssprintf(
"unsupported dtype: %s for Resize3D", src.layout.dtype.name()));
}
}
size_t Resize3DImpl::get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) {
return 0;
}
// vim: syntax=cpp.doxygen
......@@ -83,6 +83,24 @@ private:
int N, int C, int IH, int IW, int OH, int OW);
};
class Resize3DImpl final : public Resize3D {
public:
using Resize3D::Resize3D;
void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) override;
private:
template <typename ctype>
void kern_naive(
const float rdepth, const float rheight, const float rwidth,
const bool align_corners, const ctype* iptr, ctype* optr, const int N,
const int C, const int ID, const int IH, const int IW, const int OD,
const int OH, const int OW);
};
} // namespace naive
} // namespace megdnn
......
......@@ -193,6 +193,26 @@ TEST_F(CUDA, RESIZE_BACKWARD) {
}
}
TEST_F(CUDA, RESIZE3D_NCDHW) {
using IMode = param::Resize3D::InterpolationMode;
using Format = param::Resize3D::Format;
auto ac_param = param::Resize3D{IMode::LINEAR, Format::NCDHW, true};
auto nac_param = param::Resize3D{IMode::LINEAR, Format::NCDHW, false};
auto run = [&](DType d, TensorShape ishape, TensorShape oshape) {
Checker<Resize3D> checker(handle_cuda());
checker.set_param(ac_param).set_dtype(0, d).set_dtype(1, d).execs(
{ishape, oshape});
checker.set_param(nac_param).execs({ishape, oshape});
};
for (auto&& dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) {
run(dtype, {1, 1, 2, 2, 2}, {1, 1, 4, 4, 4});
run(dtype, {2, 2, 2, 3, 4}, {2, 2, 2, 3, 6});
run(dtype, {2, 2, 2, 3, 4}, {2, 2, 3, 4, 5});
}
}
#if MEGDNN_WITH_BENCHMARK
TEST_F(CUDA, BENCHMARK_RESIZE_CV) {
......
......@@ -61,3 +61,95 @@ TEST_F(NAIVE, RESIZE_NCHW4) {
.execs({arg.src, arg.dst});
}
}
TEST_F(NAIVE, RESIZE3D_NCDHW) {
using IMode = param::Resize3D::InterpolationMode;
using Format = param::Resize3D::Format;
auto ac_param = param::Resize3D{IMode::LINEAR, Format::NCDHW, true};
auto nac_param = param::Resize3D{IMode::LINEAR, Format::NCDHW, false};
Checker<Resize3D> checker(handle());
checker.set_param(nac_param).exect(
Testcase{
TensorValue(
{1, 1, 2, 2, 2}, dtype::Float32(),
{0., 1., 2., 3., 4., 5., 6., 7.}),
{}},
Testcase{
{},
TensorValue(
{1, 1, 4, 4, 4}, dtype::Float32(),
{0., 0.25, 0.75, 1., 0.5, 0.75, 1.25, 1.5, 1.5, 1.75,
2.25, 2.5, 2., 2.25, 2.75, 3., 1., 1.25, 1.75, 2.,
1.5, 1.75, 2.25, 2.5, 2.5, 2.75, 3.25, 3.5, 3., 3.25,
3.75, 4., 3., 3.25, 3.75, 4., 3.5, 3.75, 4.25, 4.5,
4.5, 4.75, 5.25, 5.5, 5., 5.25, 5.75, 6., 4., 4.25,
4.75, 5., 4.5, 4.75, 5.25, 5.5, 5.5, 5.75, 6.25, 6.5,
6., 6.25, 6.75, 7.})});
checker.set_param(ac_param).exect(
Testcase{
TensorValue(
{1, 1, 2, 2, 2}, dtype::Float32(),
{0., 1., 2., 3., 4., 5., 6., 7.}),
{}},
Testcase{
{},
TensorValue(
{1, 1, 4, 4, 4}, dtype::Float32(),
{0., 0.3333333, 0.6666667, 1., 0.6666667,
1., 1.3333333, 1.6666666, 1.3333334, 1.6666667,
1.9999999, 2.3333333, 2., 2.3333333, 2.6666665,
3., 1.3333334, 1.6666666, 2.0000002, 2.3333335,
2., 2.333333, 2.6666667, 2.9999998, 2.6666665,
3., 3.3333333, 3.6666665, 3.3333333, 3.6666665,
4., 4.3333335, 2.6666667, 3., 3.3333337,
3.6666667, 3.3333335, 3.6666663, 4., 4.333333,
3.9999998, 4.333333, 4.6666665, 5., 4.6666665,
5., 5.3333335, 5.666667, 4., 4.333333,
4.666667, 5., 4.6666665, 4.9999995, 5.3333335,
5.6666665, 5.333333, 5.6666665, 6., 6.3333335,
6., 6.333333, 6.666667, 7.})});
checker.set_param(nac_param).exect(
Testcase{
TensorValue(
{1, 1, 2, 2, 2}, dtype::Float16(),
{0., 1., 2., 3., 4., 5., 6., 7.}),
{}},
Testcase{
{},
TensorValue(
{1, 1, 4, 4, 4}, dtype::Float16(),
{0., 0.25, 0.75, 1., 0.5, 0.75, 1.25, 1.5, 1.5, 1.75,
2.25, 2.5, 2., 2.25, 2.75, 3., 1., 1.25, 1.75, 2.,
1.5, 1.75, 2.25, 2.5, 2.5, 2.75, 3.25, 3.5, 3., 3.25,
3.75, 4., 3., 3.25, 3.75, 4., 3.5, 3.75, 4.25, 4.5,
4.5, 4.75, 5.25, 5.5, 5., 5.25, 5.75, 6., 4., 4.25,
4.75, 5., 4.5, 4.75, 5.25, 5.5, 5.5, 5.75, 6.25, 6.5,
6., 6.25, 6.75, 7.})});
checker.set_param(ac_param).exect(
Testcase{
TensorValue(
{1, 1, 2, 2, 2}, dtype::Float16(),
{0., 1., 2., 3., 4., 5., 6., 7.}),
{}},
Testcase{
{},
TensorValue(
{1, 1, 4, 4, 4}, dtype::Float16(),
{0., 0.3333333, 0.6666667, 1., 0.6666667,
1., 1.3333333, 1.6666666, 1.3333334, 1.6666667,
1.9999999, 2.3333333, 2., 2.3333333, 2.6666665,
3., 1.3333334, 1.6666666, 2.0000002, 2.3333335,
2., 2.333333, 2.6666667, 2.9999998, 2.6666665,
3., 3.3333333, 3.6666665, 3.3333333, 3.6666665,
4., 4.3333335, 2.6666667, 3., 3.3333337,
3.6666667, 3.3333335, 3.6666663, 4., 4.333333,
3.9999998, 4.333333, 4.6666665, 5., 4.6666665,
5., 5.3333335, 5.666667, 4., 4.333333,
4.666667, 5., 4.6666665, 4.9999995, 5.3333335,
5.6666665, 5.333333, 5.6666665, 6., 6.3333335,
6., 6.333333, 6.666667, 7.})});
}
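The expected tensors in the test above can be cross-checked with a plain NumPy re-implementation of the same trilinear formula used by the naive and CUDA kernels. This is only a verification sketch; the helper names are made up here and are not part of the patch:

import numpy as np

def resize3d_ref(x, out_shape, align_corners):
    # x is an NCDHW float array, out_shape is (OD, OH, OW)
    N, C, ID, IH, IW = x.shape
    OD, OH, OW = out_shape

    def scale(i, o):
        if align_corners:
            return (i - 1) / (o - 1) if o > 1 else 0.0
        return i / o

    def src(s, d):
        if align_corners:
            return s * d
        return max(s * (d + 0.5) - 0.5, 0.0)

    rd, rh, rw = scale(ID, OD), scale(IH, OH), scale(IW, OW)
    out = np.empty((N, C, OD, OH, OW), dtype=x.dtype)
    for t2 in range(OD):
        for h2 in range(OH):
            for w2 in range(OW):
                tr, hr, wr = src(rd, t2), src(rh, h2), src(rw, w2)
                t1, h1, w1 = int(tr), int(hr), int(wr)
                tp = 1 if t1 < ID - 1 else 0
                hp = 1 if h1 < IH - 1 else 0
                wp = 1 if w1 < IW - 1 else 0
                tl, hl, wl = tr - t1, hr - h1, wr - w1

                def g(dt, dh, dw):  # fetch one of the 8 neighbouring voxels
                    return x[:, :, t1 + dt * tp, h1 + dh * hp, w1 + dw * wp]

                out[:, :, t2, h2, w2] = (
                    (1 - tl) * ((1 - hl) * ((1 - wl) * g(0, 0, 0) + wl * g(0, 0, 1))
                                + hl * ((1 - wl) * g(0, 1, 0) + wl * g(0, 1, 1)))
                    + tl * ((1 - hl) * ((1 - wl) * g(1, 0, 0) + wl * g(1, 0, 1))
                           + hl * ((1 - wl) * g(1, 1, 0) + wl * g(1, 1, 1)))
                )
    return out

inp = np.arange(8, dtype=np.float32).reshape(1, 1, 2, 2, 2)
print(resize3d_ref(inp, (4, 4, 4), align_corners=False).ravel()[:8])
# [0.   0.25 0.75 1.   0.5  0.75 1.25 1.5 ]  -- matches the first values of the expected tensor above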
......@@ -474,7 +474,8 @@ def interpolate(
size: the size of the output tensor. Default: None
scale_factor: scaling factor of the output tensor. Default: None
mode: interpolation methods, acceptable values are:
"bilinear", "linear", "bicubic" and "nearest". Default: "bilinear"
"bilinear", "linear", "trilinear", "bicubic" and "nearest". Default: "bilinear"
"trilinear" is valid only when inp is a 5D-tensor
align_corners: This only has an effect when ``mode``
is "bilinear" or "linear". Geometrically, we consider the pixels of the input
and output as squares rather than points. If set to ``True``, the input
......@@ -500,9 +501,9 @@ def interpolate(
>>> np.testing.assert_allclose(out.numpy(), out2.numpy())
"""
mode = mode.lower()
if mode not in ["bilinear", "linear", "bicubic", "nearest"]:
if mode not in ["bilinear", "linear", "trilinear", "bicubic", "nearest"]:
raise ValueError("unsupported interpolate mode: {}".format(mode))
if mode not in ["bilinear", "linear"]:
if mode not in ["bilinear", "linear", "trilinear"]:
if align_corners is not None:
raise ValueError(
"align_corners option can only be set in the bilinear/linear interpolating mode"
......@@ -514,14 +515,22 @@ def interpolate(
if mode == "linear":
inp = expand_dims(inp, 3)
if inp.ndim != 4:
raise ValueError("shape of input tensor must correspond to the operartion mode")
if mode == "trilinear":
assert (
inp.ndim == 5
), "under trilinear mode, input tensor must have 5 dimensions"
else:
assert (
inp.ndim == 4
), "shape of input tensor must correspond to the operartion mode"
def get_dsize(scale_factor):
if isinstance(scale_factor, (float, int)):
scale_factor = float(scale_factor)
if mode == "linear":
scale_factor = (scale_factor, float(1))
elif mode == "trilinear":
scale_factor = (scale_factor, scale_factor, scale_factor)
else:
scale_factor = (scale_factor, scale_factor)
else:
......@@ -530,21 +539,28 @@ def interpolate(
"under linear mode, scale_factor can only be single value"
)
assert len(scale_factor) == 2, "shape of scale_factor must be equal to (2, )"
assert isinstance(scale_factor[0], float) and isinstance(
scale_factor[1], float
), "scale_factor must be float type"
dsize = tuple(
if mode == "trilinear":
assert (
len(scale_factor) == 3
), f"shape of scale_factor of interpolate-{mode} must be equal to (3, )"
else:
assert (
len(scale_factor) == 2
), f"shape of scale_factor of interpolate-{mode} must be equal to (2, )"
assert all(
isinstance(x, (float, int)) for x in scale_factor
), f"scale_factor of interpolate must be float/int type"
dsize = [
floor(
Tensor(
inp.shape[i + 2] * scale_factor[i],
inp.shape[i + 2] * float(scale_factor[i]),
dtype="float32",
device=inp.device,
)
)
for i in range(2)
)
dsize = concat([dsize[0], dsize[1]], axis=0)
for i in range(len(scale_factor))
]
dsize = concat(dsize, axis=0)
return dsize
if size is None:
......@@ -557,13 +573,24 @@ def interpolate(
raise ValueError("scale_factor must be None when size is provided")
if isinstance(size, int):
size = (size, 1)
if mode == "trilinear":
size = (size, 1, 1)
else:
size = (size, 1)
else:
if mode == "linear":
raise ValueError("under linear mode, size can only be single value")
dsize = size
if not align_corners:
if mode == "trilinear":
if inp.dtype == np.float16:
inp = inp.astype("float32")
op = builtin.Resize3D(
imode="linear", format="NCDHW", align_corners=align_corners
)
shape = astensor1d(dsize, inp, dtype="int32", device=inp.device)
(ret,) = apply(op, inp, shape)
elif not align_corners:
# fastpath for interpolate
mode_map = {
"linear": "linear",
......
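With the Python-side changes above, mode="trilinear" becomes usable from interpolate for 5D NCDHW inputs. A hedged usage sketch (assuming interpolate is exposed as megengine.functional.vision.interpolate; adjust the import path to wherever it lives in your MegEngine version):

import numpy as np
import megengine as mge
import megengine.functional as F

x = mge.tensor(np.arange(8, dtype=np.float32).reshape(1, 1, 2, 2, 2))

# explicit output size (D, H, W)
y1 = F.vision.interpolate(x, size=(4, 4, 4), mode="trilinear", align_corners=False)

# or a scale factor; a single number is broadcast to (d, h, w) in trilinear mode
y2 = F.vision.interpolate(x, scale_factor=2.0, mode="trilinear", align_corners=True)

print(y1.shape, y2.shape)  # both (1, 1, 4, 4, 4)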
......@@ -75,7 +75,9 @@ def _process_fwd_bwd_trace_result(fwd, bwd, inp_grad_map, out_grad_map):
def check_external(trace_obj):
for var in trace_obj.vars:
if var.kind == "external" and not var.inp_mark:
raise RuntimeError("have unknown input in trace result")
raise RuntimeError(
"have unknown input in trace result, maybe you can set `capture_as_const=True` when trace"
)
check_external(fwd)
check_external(bwd)
......
......@@ -579,7 +579,7 @@ class trace:
if not self._trace.compiled():
outlist, self.outdef = tree_flatten(outputs)
for i, out in enumerate(outlist):
assert isinstance(out, RawTensor)
assert isinstance(out, RawTensor), f"got output of type {type(out)}"
outlist[i] = get_marked_output_tensor(self.output_num, out)
del out
self.out_list.append(self.output_num)
......
from collections import OrderedDict, defaultdict
import numpy as np
from .. import _full_sync, tensor
from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import (
is_external_convert,
set_external_convert,
set_external_convert_hook,
set_py_external_type,
unset_external_convert,
)
from ..core._trace_option import set_use_xla_backend
from ..device import get_default_device
from ..utils.dlpack import from_dlpack, to_dlpack
from .tracing import trace
try:
from mge_xlalib.xla_extension import ArrayImpl
from ..xla.lib import xla_client as xc
except ImportError:
except ImportError as e:
pass
xla_client_compute_stream = None
def apply_external_convert_hook(input, cn):
stream = xla_client_compute_stream
assert isinstance(input, ArrayImpl)
dlpack_capsule = xc._xla.buffer_to_dlpack_managed_tensor(input, take_ownership=True)
output = from_dlpack(dlpack_capsule, stream).to(cn, _borrow=True)
return output
class xla_trace(trace):
r"""Wraps a callable, and provides accelerated evaluation compiled by xla.
......@@ -53,6 +74,12 @@ class xla_trace(trace):
def __init__(self, function, *, without_host=True, symbolic_shape=False, **kwargs):
assert without_host, "xla trace only support without host mode"
assert not symbolic_shape, "xla doesn't support dynamic shape currently"
set_external_convert_hook(apply_external_convert_hook)
set_py_external_type(ArrayImpl)
set_external_convert()
super().__init__(
function, without_host=without_host, symbolic_shape=symbolic_shape, **kwargs
)
......@@ -63,24 +90,57 @@ class xla_trace(trace):
def unset_env(self):
set_use_xla_backend(self.orig_use_xla)
def convert_params_to_xla(self):
from ..utils.module_utils import get_expand_structure
from ..tensor import Tensor
backend = self.xla_exec.backend
devices = backend.local_devices()
default_cn = CompNode(get_default_device())
_, device_id, _ = default_cn.physical_locator
device_index = (
0 if len(devices) == 0 else [d.id for d in devices].index(device_id)
)
device = devices[device_index]
for attr, _ in self.attr_to_key.items():
param = get_expand_structure(attr[0], attr[1])
param._reset(param.to("cpux"))
for tensor, _ in self.opt_param_dict.items():
tensor._reset(tensor.to("cpux"))
def as_xla_array(tensor, backend, device):
np_array = tensor.numpy()
if np_array.shape == ():
np_array = np_array[np.newaxis]
xla_array = backend.buffer_from_pyval(np_array, device)
tensor._reset(Tensor(xla_array, device=default_cn))
for attr, _ in self.attr_to_key.items():
param = get_expand_structure(attr[0], attr[1])
as_xla_array(param, backend, device)
for tensor, _ in self.opt_param_dict.items():
as_xla_array(tensor, backend, device)
def compile(self):
from ..xla import build_xla
from ..traced_module.pytree import SUPPORTED_LEAF_TYPE, register_supported_type
from ..utils.module_utils import get_expand_structure
from ..xla.device import get_xla_backend_and_device
from ..tensor import Tensor
from ..distributed import get_mm_server_addr, is_distributed
assert self.traced
if self.overall:
for attr, _ in self.attr_to_key.items():
param = get_expand_structure(attr[0], attr[1])
param._reset(param.to("cpux"))
for tensor, _ in self.opt_param_dict.items():
tensor._reset(tensor.to("cpux"))
self.xla_exec, self.inp_ids, self.out_ids = build_xla(
self, return_with_io=True, return_device_array=True
self,
return_with_io=True,
return_device_array=True,
ip=get_mm_server_addr()[0] if is_distributed() else None,
port=get_mm_server_addr()[1] + 1 if is_distributed() else None,
)
if self.overall:
self.convert_params_to_xla()
id2inpidx = defaultdict(list)
id2outidx = defaultdict(list)
for idx, id in enumerate(self.inp_ids):
......@@ -137,11 +197,13 @@ class xla_trace(trace):
return inp_list
def to_dlpack(self, x, take_ownership: bool = True):
from ..xla.lib import xla_client as xc
return xc._xla.buffer_to_dlpack_managed_tensor(x, take_ownership=take_ownership)
def execute(self, *args, **kwargs):
from ..traced_module.pytree import tree_flatten
from ..tensor import Tensor
from ..traced_module.pytree import tree_flatten
from ..utils.module_utils import get_expand_structure
inputs, _ = tree_flatten((args, kwargs))
......@@ -159,6 +221,8 @@ class xla_trace(trace):
arrays = self.prepare_xla_inputs(arrays)
outputs = self.xla_exec(*arrays)
global xla_client_compute_stream
xla_client_compute_stream = xla_stream
return_vals = []
for i in self.out_list:
if i == -1:
......@@ -168,28 +232,25 @@ class xla_trace(trace):
return_vals.append(outputs[self.outkey2idx[i]])
keeped_features = []
for i in self.keeped_activation:
capsule = self.to_dlpack(outputs[self.outkey2idx[i]])
t = from_dlpack(capsule, xla_stream).to(cn, _borrow=True)
keeped_features.append(t)
keeped_features.append(tensor(outputs[self.outkey2idx[i]], device=cn))
out_tensors = []
for array in return_vals:
if array is not None:
capsule = self.to_dlpack(array)
t = from_dlpack(capsule, xla_stream)
out_tensors.append(t.to(cn, _borrow=True))
t = tensor(array, device=cn)
out_tensors.append(t)
else:
out_tensors.append(array)
if self.overall:
for attr, key in self.update_param_dict.items():
param = get_expand_structure(attr[0], attr[1])
xla_array = outputs[self.outkey2idx[key]]
capsule = self.to_dlpack(xla_array)
param._reset(from_dlpack(capsule).to(cn, _borrow=True))
t = tensor(xla_array, device=cn)
param._reset(t)
for state, key in self.update_opt_param_dict.items():
xla_array = outputs[self.outkey2idx[key]]
capsule = self.to_dlpack(xla_array)
state._reset(from_dlpack(capsule).to(cn, _borrow=True))
t = tensor(xla_array, device=cn)
state._reset(t)
rst = (
self.outdef.unflatten(out_tensors)
if hasattr(self, "outdef")
......
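For context, xla_trace is constructed like the regular trace object it derives from, and extra keyword arguments such as capture_as_const are forwarded to trace. A rough, hedged sketch of the intended call pattern; the import path and the exact requirements of without-host mode are assumptions, not spelled out in this diff:

import numpy as np
import megengine as mge
import megengine.functional as F
from megengine.jit.xla_backend import xla_trace  # assumed module path

w = mge.Parameter(np.random.rand(4, 4).astype("float32"))

def step(x):
    return F.matmul(x, w).sum()

# without_host=True is asserted inside xla_trace; capture_as_const lets captured
# parameters be treated as constant inputs, as the error message above suggests
traced_step = xla_trace(step, without_host=True, capture_as_const=True)

x = mge.tensor(np.random.rand(2, 4).astype("float32"))
y = traced_step(x)  # first call traces, lowers to XLA and compiles; later calls reuse the executable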
......@@ -224,11 +224,11 @@ class GeneralNorm(Module):
Examples:
>>> import numpy as np
>>> inp = Tensor(np.arange(2 * 3 * 4 * 4).astype(np.float32).reshape(2, 3, 4, 4))
>>> m = M.GeneralNorm((2, 4), (0, 2))
>>> m = M.GeneralNorm((2, 3), (0, 1))
>>> out = m(inp)
>>> out.numpy().shape
(2, 3, 4, 4)
>>> m = M.GeneralNorm((3, 4), (1, -1)) # Please be careful.
>>> m = M.GeneralNorm((3, 4), (1, -2)) # Please be careful.
>>> out = m(inp)
>>> out.numpy().shape
(2, 3, 4, 4)
......
# some code of this directory is from jax: https://github.com/google/jax
from .build import build_xla
import os
from ..distributed import get_rank, get_world_size, is_distributed
from .compile import MeshComputation, PmapComputation
from .device import get_xla_backend_and_device
from .distribute import initialize
from .ir_utils import DropoutMaskCanonicalizer, RngKeyAdder, TraceResult
from .lib import xla_client as xc
from .lower import lower
from .sharding import OpShardingSharding, _is_unspecified, make_unspec_sharding
xla_extention = xc._xla
xe = xla_extention
Backend = xe.Client
def build_xla(
mge_traced,
func_name=None,
device=None,
keep_unused=True,
donate_invars=None,
verbose=int(os.environ.get("MGE_VERBOSE_XLA_IR", "0")),
return_with_io=False,
return_device_array=False,
ip: str = None,
port: int = None,
):
assert device == None, "cannot specify device now"
assert keep_unused == True, "keep_unused error"
assert donate_invars == None, "donate_invars error"
# normalize megengine trace result for lowering
tr = TraceResult(mge_traced, func_name)
tr = RngKeyAdder()(tr)
tr = DropoutMaskCanonicalizer()(tr)
if verbose and get_rank() == 0:
print("================ Mge Trace Result ================")
print(tr)
in_is_global = (True,) * len(tr.inputs)
kept_var_idx = set(range(len(tr.inputs))) if keep_unused else set()
# init for xla distributed and setup device
if is_distributed():
initialize(ip, port, get_world_size(), get_rank(), [get_rank()])
backend, device_assignment, platform = get_xla_backend_and_device(device)
module, keepalive, host_callbacks = lower(
tr, backend, platform, None, None, donate_invars,
)
if not is_distributed():
# setup sharding information
in_shardings = make_unspec_sharding(tr.inputs)
out_shardings = make_unspec_sharding(tr.outputs)
in_shardings = tuple(
OpShardingSharding.get_replicated(device_assignment)
if _is_unspecified(i)
else i
for i in in_shardings
)
computation = MeshComputation(
tr.func_name,
module,
donated_invars=donate_invars,
trace_result=tr,
mesh=None,
in_shardings=in_shardings,
out_shardings=out_shardings,
spmd_lowering=False,
tuple_args=False, # for tpu
in_is_global=in_is_global,
auto_spmd_lowering=False,
unordered_effects=[],
ordered_effects=[],
host_callbacks=host_callbacks,
keepalive=keepalive,
kept_var_idx=kept_var_idx,
backend=backend,
device_assignment=device_assignment,
committed=False, # unknown
pmap_nreps=1,
return_device_array=return_device_array,
)
else:
computation = PmapComputation(
tr.func_name,
module,
trace_result=tr,
unordered_effects=[],
ordered_effects=[],
tuple_args=False, # for tpu
in_is_global=in_is_global,
host_callbacks=host_callbacks,
keepalive=keepalive,
kept_var_idx=kept_var_idx,
backend=backend,
devices=None,
return_device_array=return_device_array,
world_size=get_world_size(),
rank=get_rank(),
)
if verbose and get_rank() == 0:
print("================ XLA HLO IR ================")
print(computation.as_text())
compiled = computation.compile()
if verbose and get_rank() == 0:
print("================ XLA Execute Plan ================")
print(compiled.as_text())
ret = compiled.unsafe_call
if return_with_io:
return ret, tr.inputs, tr.outputs
return ret
(This diff has been collapsed.)
import itertools as it
from typing import Sequence, Tuple, Union
import numpy as np
from ..core._imperative_rt.common import CompNode
from ..tensor import Parameter as MgeParameter
from ..tensor import Tensor as MgeTensor
from .dtype import (
_np_types,
_python_scalar_dtypes,
_scalar_type_to_dtype,
canonicalize_arg,
)
from .lib import xla_bridge as xb
from .lib import xla_client as xc
from .utils import safe_zip
xla_extention = xc._xla
xe = xla_extention
Backend = xe.Client
device_put_handlers = {}
def _device_put_nparray(x, device):
backend = xb.get_device_backend(device)
return (backend.buffer_from_pyval(x, device),)
def _device_put_scalar(x, device):
def cvt_scalar_to_nparray(x, dtype=None):
if dtype is None and type(x) in _python_scalar_dtypes:
dtype = _scalar_type_to_dtype(type(x), x)
return np.asarray(x, dtype)
return _device_put_nparray(cvt_scalar_to_nparray(x), device)
def _device_put_device_array(x, device):
assert False
def _device_put_mge_tensor(x, device):
x = x.numpy()
return _device_put_nparray(x, device)
for nt in _np_types:
device_put_handlers[nt] = _device_put_nparray
for sc in _python_scalar_dtypes:
device_put_handlers[sc] = _device_put_scalar
device_put_handlers[xc._xla.DeviceArray] = _device_put_device_array
device_put_handlers[MgeTensor] = _device_put_mge_tensor
device_put_handlers[MgeParameter] = _device_put_mge_tensor
def _device_put_impl(x, device):
x = canonicalize_arg(x)
return device_put_handlers[type(x)](x, device)
def device_put(x, devices: Sequence[xb.xla_client.Device], replicate: bool = False):
if replicate:
return list(
it.chain.from_iterable(_device_put_impl(x, device) for device in devices)
)
else:
return list(
it.chain.from_iterable(
_device_put_impl(val, device) for val, device in safe_zip(x, devices)
)
)
def get_xla_backend_and_device(device=None) -> Tuple[Backend, Sequence[xc.Device]]:
assert device is None, "device assignment is not supported yet"
device_assignment = [xb.local_devices()[0]]
backend = xb.get_device_backend(device_assignment[0])
platform = backend.platform
platform = xb.canonicalize_platform(platform)
assert xb.is_known_platform(platform), f"{platform} is not known yet"
assert platform == "cuda", f"only cuda platfrom is supportted, but get {platform}"
return backend, device_assignment, platform
import atexit
import warnings
from typing import Any, Optional, Sequence, Union
from .lib import xla_client as xc
xla_extention = xc._xla
xe = xla_extention
class State:
process_id: int = 0
ip: str = None
port: int = None
service: Optional[Any] = None
client: Optional[Any] = None
preemption_sync_manager: Optional[Any] = None
visible_devices: Optional[str] = "all"
def initialize(
self,
ip: str,
port: int,
num_processes: int,
process_id: int,
local_device_ids: Optional[Union[int, Sequence[int]]] = None,
):
coordinator_address = ip + ":" + str(port)
if local_device_ids is None:
local_device_ids = [process_id]
elif isinstance(local_device_ids, int):
local_device_ids = [local_device_ids]
else:
local_device_ids = list(local_device_ids)
assert local_device_ids == [process_id], f"{local_device_ids} vs. {process_id}"
self.ip = ip
self.port = port
self.visible_devices = ",".join(str(x) for x in local_device_ids)
self.process_id = process_id
if process_id == 0:
if self.service is not None:
raise RuntimeError("distributed.initialize should only be called once.")
self.service = xe.get_distributed_runtime_service(
coordinator_address, num_processes, use_coordination_service=True
)
if self.client is not None:
raise RuntimeError("distributed.initialize should only be called once.")
# Set init_timeout to 5 min to leave time for all the processes to connect
self.client = xe.get_distributed_runtime_client(
coordinator_address,
process_id,
use_coordination_service=True,
init_timeout=300,
)
self.client.connect()
self.initialize_preemption_sync_manager()
def shutdown(self):
if self.client:
self.client.shutdown()
self.client = None
if self.service:
self.service.shutdown()
self.service = None
if self.preemption_sync_manager:
self.preemption_sync_manager = None
def initialize_preemption_sync_manager(self):
if self.preemption_sync_manager is not None:
raise RuntimeError(
"Preemption sync manager should only be initialized once."
)
self.preemption_sync_manager = xe.create_preemption_sync_manager()
self.preemption_sync_manager.initialize(self.client)
global_state = State()
def initialize(
ip: str,
port: int,
num_processes: int,
process_id: int,
local_device_ids: Optional[Union[int, Sequence[int]]] = None,
):
ip = "127.0.0.1" if ip == "localhost" else ip
if global_state.service == None and global_state.client == None:
global_state.initialize(ip, port, num_processes, process_id, local_device_ids)
atexit.register(shutdown)
else:
assert (
global_state.client != None
), "global_state.client should not be None if server is created"
if global_state.ip == ip and global_state.port == port:
return
else:
msg = (
f"xla distribute server/client have been created on {global_state.ip}:{global_state.port}. "
f"so ignore the request to create on {ip}:{port}"
)
warnings.warn(msg, category=RuntimeWarning)
def shutdown():
global_state.shutdown()
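build_xla only calls into this module when is_distributed() is true, passing the trainer's ip/port and using the local rank as both the process id and the visible device id. A hedged standalone sketch of the same sequence (the import path is assumed and the values are illustrative):

# rank 0 also hosts the coordination service; every process creates a client and connects
from megengine.xla.distribute import initialize, shutdown  # assumed import path

world_size, rank = 2, 0  # illustrative values
initialize("127.0.0.1", 23456, num_processes=world_size, process_id=rank,
           local_device_ids=[rank])
# ... lower and run the XLA computation ...
shutdown()  # initialize() also registers shutdown via atexit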
from functools import lru_cache, partial
import numpy as np
from ..tensor import Parameter as MgeParameter
from ..tensor import Tensor as MgeTensor
from .lib import xla_client as xc
_python_scalar_dtype_to_npdtypes = {
bool: np.dtype("bool"),
int: np.dtype("int64"),
float: np.dtype("float64"),
complex: np.dtype("complex128"),
}
_python_scalar_dtypes = list(_python_scalar_dtype_to_npdtypes.keys())
bfloat16 = xc.bfloat16
_bfloat16_dtype = np.dtype(bfloat16)
_float_types = [
_bfloat16_dtype,
np.dtype("float16"),
np.dtype("float32"),
np.dtype("float64"),
]
_numpy_scalar_types = {
np.int8,
np.int16,
np.int32,
np.int64,
np.uint8,
np.uint16,
np.uint32,
np.uint64,
np.complex64,
np.complex128,
np.bool_,
np.longlong,
np.intc,
} | set(np.dtype(dt).type for dt in _float_types)
_np_types = {np.ndarray} | _numpy_scalar_types
_dtype_to_32bit_dtype = {
np.dtype("int64"): np.dtype("int32"),
np.dtype("uint64"): np.dtype("uint32"),
np.dtype("float64"): np.dtype("float32"),
np.dtype("complex128"): np.dtype("complex64"),
}
def _scalar_type_to_dtype(typ, value):
dtype = canonicalize_dtype(_python_scalar_dtype_to_npdtypes[typ])
if typ is int and value is not None:
if value < np.iinfo(dtype).min or value > np.iinfo(dtype).max:
raise OverflowError(f"Python int {value} too large to convert to {dtype}")
return dtype
# do not enable x64 because megengine only supports x32
@lru_cache(maxsize=None)
def canonicalize_dtype(dtype, x64_enabled=False, allow_opaque_dtype=False):
assert allow_opaque_dtype == False and x64_enabled == False
try:
dtype_ = np.dtype(dtype)
except TypeError as e:
raise TypeError(f"dtype {dtype!r} not understood") from e
if x64_enabled:
return dtype_
else:
return _dtype_to_32bit_dtype.get(dtype_, dtype_)
def _canonicalize_ndarray_dtype(x):
return np.asarray(x, canonicalize_dtype(x.dtype))
def _canonicalize_python_scalar_dtype(typ, x):
return np.asarray(x, canonicalize_dtype(_scalar_type_to_dtype(typ, x)))
def _canonicalize_mgetensor_dtype(x: MgeTensor):
canonicalized = canonicalize_dtype(x.dtype)
if canonicalized != x.dtype:
return x.astype(canonicalized)
return x
canonicalize_args_handlers = {}
canonicalize_args_handlers.update(
(t, _canonicalize_ndarray_dtype) for t in _numpy_scalar_types
)
canonicalize_args_handlers[np.ndarray] = _canonicalize_ndarray_dtype
canonicalize_args_handlers.update(
(t, partial(_canonicalize_python_scalar_dtype, t)) for t in _python_scalar_dtypes
)
canonicalize_args_handlers[MgeTensor] = _canonicalize_mgetensor_dtype
canonicalize_args_handlers[MgeParameter] = _canonicalize_mgetensor_dtype
def canonicalize_arg(x):
typ = type(x)
handler = canonicalize_args_handlers.get(typ)
if handler:
return handler(x)
raise TypeError(f"No canonicalize_dtype handler for type: {type(x)}")
import io
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Callable, Dict, Sequence, Tuple
import numpy as np
from ..core._imperative_rt import ops as mops
from ..core._imperative_rt.core2 import OpInfo, VarInfo
from . import dtype
from .lib.mlir import ir
from .lib.mlir.dialects import hlo
func_id = 0
def _default_func_name():
global func_id
func_id += 1
return f"please_realize_func_name_system_{func_id}"
def _is_rng_op(opr):
return isinstance(
opr,
(
mops.Dropout,
mops.BetaRNG,
mops.GammaRNG,
mops.GaussianRNG,
mops.PermutationRNG,
mops.PoissonRNG,
mops.ShuffleRNG,
mops.UniformRNG,
),
)
class AbstractVar:
def __init__(self, _id, _shape, _dtype) -> None:
self.id = _id
self.shape = _shape
self.dtype = _dtype
self.bound_data = None
class Pass(ABC):
def __init__(self) -> None:
pass
@abstractmethod
def __call__(self, tr) -> Any:
pass
# because xla passes the rng key as a tensor while mge passes it as a param, we need to add a
# rng key tensor to the graph and set it as an input of both the graph and the rng ops
class RngKeyAdder(Pass):
def __call__(self, tr) -> Any:
has_rng_opr = False
for eqn in tr.eqns:
if _is_rng_op(eqn.op):
has_rng_opr = True
break
if not has_rng_opr:
return tr
# it should be [2, np.uint64]; however, megengine does not support np.uint64/np.int64/np.uint32
inp_rng_state_var = AbstractVar(tr.next_vid, [2, 2], np.dtype(np.int32))
tr.add_input(inp_rng_state_var)
new_eqns = []
for eqn in tr.eqns:
if not _is_rng_op(eqn.op):
new_eqns.append(eqn)
continue
oup_rng_state_var = AbstractVar(tr.next_vid, [2, 2], np.dtype(np.int32))
tr.add_var(oup_rng_state_var)
inputs, outputs = list(eqn.inputs), list(eqn.outputs)
inputs.append(inp_rng_state_var.id)
outputs.append(oup_rng_state_var.id)
new_eqn = OpInfo(eqn.op, inputs, outputs, eqn.kind)
new_eqns.append(new_eqn)
inp_rng_state_var = oup_rng_state_var
tr.eqns = new_eqns
tr.set_var_as_oup(inp_rng_state_var)
return tr
# in megengine, dropout returns a bit-mask that is hard to represent in xla, so we let xla
# return a uint8 mask, which means the mask is 8 times larger than in mge
class DropoutMaskCanonicalizer(Pass):
def __call__(self, tr) -> Any:
for eqn in tr.eqns:
if not isinstance(eqn.op, mops.Dropout):
continue
inputs, outputs = list(eqn.inputs), list(eqn.outputs)
mask_var = tr.vars[outputs[1]]
inp_shape = tr.vars[inputs[0]].shape
new_mask_var = AbstractVar(
mask_var.id, (int(np.prod(inp_shape)),), mask_var.dtype
)
tr.vars[mask_var.id] = new_mask_var
return tr
class TraceResult:
def __init__(self, traced, func_name=None) -> None:
self.func_name = func_name if func_name is not None else _default_func_name()
self.traced = traced
self.eqns = []
self.vars = {}
self.inputs = []
self.outputs = []
self.consts = []
self.custom_vid = 0
self.effects = []
for var in self.traced.vars:
self.add_var(var)
self.custom_vid = max(var.id + 1, self.custom_vid)
if var.kind == "external" and var.inp_mark:
self.inputs.append(var.id)
if var.data_required:
self.outputs.append(var.id)
if var.kind == "const":
self.consts.append(var.id)
for op in self.traced.ops:
self.eqns.append(op)
@property
def _var_inputs(self):
return [self.vars[i] for i in self.inputs]
@property
def _var_outputs(self):
return [self.vars[i] for i in self.outputs]
@property
def _var_consts(self):
return [self.vars[i] for i in self.consts]
@property
def next_vid(self):
ret = self.custom_vid
self.custom_vid += 1
return ret
def add_var(self, var):
assert var.id not in self.vars
self.vars[var.id] = var
def add_input(self, inp_var):
self.add_var(inp_var)
self.inputs.append(inp_var.id)
def set_var_as_oup(self, oup_var):
assert oup_var.id in self.vars
self.outputs.append(oup_var.id)
def get_var(self, idx):
assert isinstance(idx, int)
return self.vars[idx]
def is_input(self, var):
if isinstance(var, int):
var = self.vars[var]
return var.kind == "external"
def is_output(self, var):
if isinstance(var, int):
var = self.vars[var]
return var.data_required
def _str_var(self, var):
def _str_shape(shp):
return "x".join([str(d) for d in shp])
dtype_to_str = {
"float16": "f16",
"float32": "f32",
"int32": "i32",
"int64": "i64",
"uint8": "u8",
"uint32": "u32",
"uint64": "u64",
"bool": "i1-bool",
}
if isinstance(var, int):
var = self.vars[var]
var_dtype = None
try:
var_dtype = dtype_to_str[str(var.dtype)]
except (KeyError, RuntimeError):
var_dtype = "unknown"
var_bound_data = (
("," + ",".join(str(var.bound_data).split()))
if var.bound_data is not None and var.bound_data.size < 5
else ""
)
return f"{var.id}%:<{_str_shape(var.shape)},{var_dtype}{var_bound_data}>"
def _str_eqn(self, eqn):
inps = ", ".join(map(self._str_var, eqn.inputs))
oups = ", ".join(map(self._str_var, eqn.outputs))
str_op = str(eqn.op)
if isinstance(eqn.op, mops.Reduce):
assert str(eqn.op.mode).startswith("Reduce.Mode.")
str_op = str_op + str(eqn.op.mode)[len("Reduce.Mode.") :]
ret = f"{oups} = {str_op}({inps})"
return ret
def __str__(self) -> str:
func_inps_str = ", ".join(map(self._str_var, self.inputs))
func_oups_str = ", ".join(map(self._str_var, self.outputs))
func_const_str = "\n ".join(map(self._str_var, self.consts))
ret = f"{self.func_name}({func_inps_str}) -> ({func_oups_str}) {{\n "
if len(self.consts) > 0:
ret += f"const:\n {func_const_str}\n "
ret += "\n ".join(map(self._str_eqn, self.eqns))
ret += "\n}"
return ret
_dtype_to_ir_type: Dict[np.dtype, Callable[[], ir.Type]] = {
np.dtype(np.bool_): partial(ir.IntegerType.get_signless, 1),
np.dtype(np.int8): partial(ir.IntegerType.get_signless, 8),
np.dtype(np.int16): partial(ir.IntegerType.get_signless, 16),
np.dtype(np.int32): partial(ir.IntegerType.get_signless, 32),
np.dtype(np.int64): partial(ir.IntegerType.get_signless, 64),
np.dtype(np.uint8): partial(ir.IntegerType.get_unsigned, 8),
np.dtype(np.uint16): partial(ir.IntegerType.get_unsigned, 16),
np.dtype(np.uint32): partial(ir.IntegerType.get_unsigned, 32),
np.dtype(np.uint64): partial(ir.IntegerType.get_unsigned, 64),
np.dtype(dtype.bfloat16): ir.BF16Type.get,
np.dtype(np.float16): ir.F16Type.get,
np.dtype(np.float32): ir.F32Type.get,
np.dtype(np.float64): ir.F64Type.get,
np.dtype(np.complex64): lambda: ir.ComplexType.get(ir.F32Type.get()),
np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()),
}
def mge_dtype_to_ir_type(mge_dtype):
mge_dtype = np.dtype(mge_dtype)
assert isinstance(
mge_dtype, np.dtype
), f"arg should be numpy dtype, but is {mge_dtype}"
ir_type_factory = _dtype_to_ir_type[mge_dtype]
return ir_type_factory()
def mge_varinfo_to_ir_type(mge_varinfo):
assert isinstance(mge_varinfo, (VarInfo, AbstractVar)), "args should be VarInfo"
shape = mge_varinfo.shape
return ir.RankedTensorType.get(shape, mge_dtype_to_ir_type(mge_varinfo.dtype))
def mge_varinfo_to_ir_type_tuple(mge_varinfo):
return (mge_varinfo_to_ir_type(mge_varinfo),)
def make_ir_type_according_meta(src_shape: Tuple, src_dtype: np.dtype):
return ir.RankedTensorType.get(src_shape, mge_dtype_to_ir_type(src_dtype))
def make_ir_type_according_meta_tuple(src_shape: Tuple, src_dtype: np.dtype):
return (make_ir_type_according_meta(src_shape, src_dtype),)
_constant_handlers = {}
def _numpy_array_constant(x: np.ndarray, canonicalize_types) -> Sequence[ir.Value]:
if canonicalize_types:
x = np.asarray(x, dtype.canonicalize_dtype(x.dtype))
element_type = mge_dtype_to_ir_type(x.dtype)
shape = x.shape
if x.dtype == np.bool_:
nelems = x.size
x = np.packbits(x, bitorder="little")
if nelems == 1:
x = np.array(0 if x.item() == 0 else 0xFF, np.uint8)
elif x.dtype == dtype.bfloat16:
x = x.view(np.uint16)
x = np.ascontiguousarray(x)
attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape)
return (hlo.ConstantOp(attr).result,)
def _ndarray_constant_handler(
val: np.ndarray, canonicalize_types
) -> Sequence[ir.Value]:
if np.any(np.equal(0, val.strides)) and val.size > 0:
(zero_stride_axes,) = np.where(np.equal(0, val.strides))
(other_axes,) = np.where(np.not_equal(0, val.strides))
collapsed_val = val[
tuple(
0 if ax in zero_stride_axes else slice(None) for ax in range(val.ndim)
)
]
if canonicalize_types:
collapsed_val = np.asarray(
collapsed_val, dtype.canonicalize_dtype(collapsed_val.dtype)
)
out = hlo.BroadcastInDimOp(
ir.RankedTensorType.get(
val.shape, mge_dtype_to_ir_type(collapsed_val.dtype)
),
_numpy_array_constant(collapsed_val, canonicalize_types=False)[0],
dense_int_elements(other_axes),
).result
return (out,)
else:
return _numpy_array_constant(val, canonicalize_types)
_constant_handlers[np.ndarray] = _ndarray_constant_handler
for _scalar_type in [
np.int8,
np.int16,
np.int32,
np.int64,
np.uint8,
np.uint16,
np.uint32,
np.uint64,
np.float16,
np.float32,
np.float64,
np.complex64,
np.complex128,
np.bool_,
np.longlong,
dtype.bfloat16,
]:
_constant_handlers[_scalar_type] = _ndarray_constant_handler
def _python_scalar_constant_handler(dtype, val, canonicalize_dtypes):
return _numpy_array_constant(np.array(val, dtype), canonicalize_dtypes)
for pt, dt in dtype._python_scalar_dtype_to_npdtypes.items():
_constant_handlers[pt] = partial(_python_scalar_constant_handler, dt)
def _mge_varinfo_constant_handler(val, canonicalize_dtypes):
assert isinstance(val, VarInfo)
assert val.bound_data is not None and val.kind == "const"
assert isinstance(val.bound_data, np.ndarray)
return _numpy_array_constant(
np.asarray(val.bound_data, val.dtype), canonicalize_dtypes
)
_constant_handlers[VarInfo] = _mge_varinfo_constant_handler
def ir_constant_tuple(val: Any, canonicalize_types: bool = True) -> Sequence[ir.Value]:
for t in type(val).__mro__:
handler = _constant_handlers.get(t)
if handler:
out = handler(val, canonicalize_types)
assert all(isinstance(v, ir.Value) for v in out), (type(val), out)
return out
assert False
def ir_constant(val: Any, canonicalize_types: bool = True) -> Sequence[ir.Value]:
values = ir_constant_tuple(val, canonicalize_types=canonicalize_types)
assert len(values) == 1
return values[0]
def token_type() -> Sequence[ir.Type]:
return [hlo.TokenType.get()]
def dummy_token_type_tuple() -> Sequence[ir.Type]:
return make_ir_type_according_meta_tuple((0,), np.bool_)
def dummy_token() -> Sequence[ir.Value]:
return ir_constant_tuple(np.zeros(0, np.bool_))
def i32_attr(i):
return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), i)
def i64_attr(i):
return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), i)
def ui64_attr(i):
return ir.IntegerAttr.get(ir.IntegerType.get_unsigned(64), i)
def f32_attr(i):
return ir.FloatAttr.get(ir.F32Type.get(), i)
def precision_attr(lhs_prec, rhs_prec) -> ir.ArrayAttr:
lhs_prec = str(lhs_prec)
rhs_prec = str(rhs_prec)
assert lhs_prec == "float32"
assert rhs_prec == "float32"
dtype_to_precision = {
"float32": "DEFAULT",
}
precision = (dtype_to_precision[lhs_prec], dtype_to_precision[rhs_prec])
return ir.ArrayAttr.get([hlo.PrecisionAttr.get(p) for p in precision])
def dense_int_elements(xs) -> ir.DenseIntElementsAttr:
return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr:
a = np.packbits(np.array(xs, np.bool_), bitorder="little")
if len(xs) == 1:
a = np.array(0 if a.item() == 0 else 0xFF, np.uint8)
return ir.DenseElementsAttr.get(
a, type=ir.IntegerType.get_signless(1), shape=[len(xs)]
)
def get_irnode_shape(irnode):
if isinstance(irnode, (list, tuple, ir.OpResultList)):
assert len(irnode) == 1
irnode = irnode[0]
assert isinstance(irnode, (ir.RankedTensorType, ir.BlockArgument, ir.OpResult))
if not isinstance(irnode, ir.RankedTensorType):
irnode = ir.RankedTensorType(irnode.type)
return tuple(irnode.shape)
def get_irnode_dtype(irnode):
if isinstance(irnode, (list, tuple, ir.OpResultList)):
assert len(irnode) == 1
irnode = irnode[0]
assert isinstance(
irnode, (ir.RankedTensorType, ir.BlockArgument, ir.OpResult)
), type(irnode)
if not isinstance(irnode, ir.RankedTensorType):
irnode = ir.RankedTensorType(irnode.type)
etype = irnode.element_type
for k, v in _dtype_to_ir_type.items():
if etype == v():
return k
assert False, f"unknown irnode {irnode}"
def module_to_string(module: ir.Module) -> str:
output = io.StringIO()
module.operation.print(
file=output, enable_debug_info=True, print_generic_op_form=False
)
return output.getvalue()
def module_to_bytecode(module: ir.Module) -> bytes:
output = io.BytesIO()
module.operation.write_bytecode(file=output)
return output.getvalue()
# code of this directory is mainly from jax: https://github.com/google/jax
try:
import mge_xlalib as mge_xlalib
except ModuleNotFoundError as err:
msg = (
"mge-xla requires mge_xlalib to be installed. If this problem happened while "
"running pytest, you may have set --doctest-modules for pytest; you can disable "
"it for this directory by setting `norecursedirs = megengine/xla` in `pytest.ini`"
)
raise ModuleNotFoundError(msg) from err
import gc
import os
import platform
import subprocess
import sys
import warnings
from typing import Optional
import mge_xlalib.cpu_feature_guard as cpu_feature_guard
import mge_xlalib.ducc_fft as ducc_fft
import mge_xlalib.gpu_linalg as gpu_linalg
import mge_xlalib.gpu_prng as gpu_prng
import mge_xlalib.gpu_rnn as gpu_rnn
import mge_xlalib.gpu_solver as gpu_solver
import mge_xlalib.gpu_sparse as gpu_sparse
import mge_xlalib.lapack as lapack
import mge_xlalib.xla_client as xla_client
from ...core._imperative_rt.common import get_cudnn_version as _get_cudnn_version
if int(platform.python_version_tuple()[1]) < 8:
raise RuntimeError(
f"xla backend requires Python version >= 3.8, got {platform.python_version()}"
)
if _get_cudnn_version() < 8600:
warnings.warn(
"the xla backend reaches its maximum speedup with cuDNN >= 8.6.0, "
f"but the current cuDNN version is {_get_cudnn_version()}"
)
cpu_feature_guard.check_cpu_features()
xla_extension = xla_client._xla
pytree = xla_client._xla.pytree
# we reuse some APIs from jaxlib
xla_jit = xla_client._xla.jax_jit
pmap_lib = xla_client._xla.pmap_lib
def _xla_gc_callback(*args):
xla_client._xla.collect_garbage()
gc.callbacks.append(_xla_gc_callback)
xla_extension_version: int = getattr(xla_client, "_version", 0)
mlir_api_version = xla_client.mlir_api_version
# Finds the CUDA install path
def _find_cuda_root_dir() -> Optional[str]:
cuda_root_dir = os.environ.get("CUDA_ROOT_DIR")
if cuda_root_dir is None:
try:
which = "where" if sys.platform == "win32" else "which"
with open(os.devnull, "w") as devnull:
nvcc = (
subprocess.check_output([which, "nvcc"], stderr=devnull)
.decode()
.rstrip("\r\n")
)
cuda_root_dir = os.path.dirname(os.path.dirname(nvcc))
except Exception:
if sys.platform == "win32":
assert False, "xla not supported on windows"
else:
cuda_root_dir = "/usr/local/cuda"
if not os.path.exists(cuda_root_dir):
cuda_root_dir = None
return cuda_root_dir
cuda_path = _find_cuda_root_dir()
transfer_guard_lib = xla_client._xla.transfer_guard_lib
import mge_xlalib.mlir.dialects.builtin as builtin
import mge_xlalib.mlir.dialects.chlo as chlo
import mge_xlalib.mlir.dialects.func as func
import mge_xlalib.mlir.dialects.mhlo as mhlo
import mge_xlalib.mlir.dialects.ml_program as ml_program
import mge_xlalib.mlir.dialects.sparse_tensor as sparse_tensor
import mge_xlalib.mlir.dialects.stablehlo as stablehlo
hlo = stablehlo
# the code in this file is mainly adapted from jax: https://github.com/google/jax
import logging
import os
import platform as py_platform
import threading
import warnings
from functools import lru_cache, partial
from typing import Any, Dict, List, Optional, Union
import numpy as np
from mge_xlalib import xla_client
from ..lib import cuda_path
from .config import bool_env, config, flags, int_env
XlaBackend = xla_client._xla.Client
ShardedBuffer = Any
FLAGS = flags.FLAGS
logger = logging.getLogger(__name__)
flags.DEFINE_string(
"xla_backend", "", "Deprecated, please use --xla_platforms instead."
)
flags.DEFINE_string(
"xla_backend_target",
os.getenv("XLA_BACKEND_TARGET", "").lower(),
'Either "local" or "rpc:address" to connect to a remote service target.',
)
flags.DEFINE_string(
"xla_platform_name",
os.getenv("XLA_PLATFORM_NAME", "").lower(),
"Deprecated, please use --xla_platforms instead.",
)
flags.DEFINE_bool(
"xla_disable_most_optimizations",
bool_env("XLA_DISABLE_MOST_OPTIMIZATIONS", False),
"Try not to do much optimization work. This can be useful if the cost of "
"optimization is greater than that of running a less-optimized program.",
)
flags.DEFINE_integer(
"xla_profile_version",
int_env("XLA_PROFILE_VERSION", 0),
"Optional profile version for XLA compilation. "
"This is meaningful only when XLA is configured to "
"support the remote compilation profile feature.",
)
flags.DEFINE_string(
"xla_cuda_visible_devices",
"all",
'Restricts the set of CUDA devices that XLA will use. Either "all", or a '
"comma-separate list of integer device IDs.",
)
flags.DEFINE_string(
"xla_rocm_visible_devices",
"all",
'Restricts the set of ROCM devices that XLA will use. Either "all", or a '
"comma-separate list of integer device IDs.",
)
def get_compile_options(
num_replicas: int,
num_partitions: int,
device_assignment=None,
use_spmd_partitioning: bool = True,
use_auto_spmd_partitioning: bool = False,
auto_spmd_partitioning_mesh_shape=[],
auto_spmd_partitioning_mesh_ids=[],
) -> xla_client.CompileOptions:
"""Returns the compile options to use, as derived from flag values.
Args:
num_replicas: Number of replicas for which to compile.
num_partitions: Number of partitions for which to compile.
device_assignment: Optional ndarray of xla devices indicating the assignment
of logical replicas to physical devices (default inherited from
xla_client.CompileOptions). Must be consistent with `num_replicas` and
`num_partitions`.
use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
partitioning in XLA.
use_auto_spmd_partitioning: boolean indicating whether to automatically
generate XLA shardings for SPMD partitioner.
auto_spmd_partitioning_mesh_shape: device mesh shape used to create
auto_spmd_partitioning search space.
auto_spmd_partitioning_mesh_ids: device ids used to create
auto_spmd_partitioning search space.
"""
compile_options = xla_client.CompileOptions()
compile_options.num_replicas = num_replicas
compile_options.num_partitions = num_partitions
build_options = compile_options.executable_build_options
build_options.use_spmd_partitioning = use_spmd_partitioning
build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning
if use_auto_spmd_partitioning:
build_options.auto_spmd_partitioning_mesh_shape = (
auto_spmd_partitioning_mesh_shape
)
build_options.auto_spmd_partitioning_mesh_ids = auto_spmd_partitioning_mesh_ids
if device_assignment is not None:
logger.debug(
"get_compile_options: num_replicas=%s num_partitions=%s device_assignment=%s",
num_replicas,
num_partitions,
device_assignment,
)
device_assignment = np.array(device_assignment)
# Allow 1D device assignment if num_partitions is 1.
if (device_assignment.ndim == 1) and (num_partitions == 1):
device_assignment = device_assignment[:, None]
if num_replicas != device_assignment.shape[0]:
msg = "device_assignment does not match num_replicas: {} vs {}."
raise ValueError(msg.format(device_assignment, num_replicas))
if num_partitions != device_assignment.shape[1]:
msg = "device_assignment does not match num_partitions: {} vs {}."
raise ValueError(msg.format(device_assignment, num_partitions))
if device_assignment.dtype == object:
device_assignment = np.vectorize(lambda d: d.id, otypes=[int])(
device_assignment
)
device_assignment = xla_client.DeviceAssignment.create(device_assignment)
assert device_assignment.replica_count() == num_replicas
assert device_assignment.computation_count() == num_partitions
compile_options.device_assignment = device_assignment
debug_options = compile_options.executable_build_options.debug_options
if cuda_path is not None:
debug_options.xla_gpu_cuda_data_dir = cuda_path
if FLAGS.xla_disable_most_optimizations:
debug_options.xla_backend_optimization_level = 0
debug_options.xla_llvm_disable_expensive_passes = True
debug_options.xla_test_all_input_layouts = False
compile_options.profile_version = FLAGS.xla_profile_version
return compile_options
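# Illustrative sketch (not part of the original source): the common call path is
# a single-replica, single-partition compile that relies on the defaults above,
# letting the device assignment be inherited from xla_client.CompileOptions.
def _example_default_compile_options() -> xla_client.CompileOptions:
    return get_compile_options(num_replicas=1, num_partitions=1)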
# Backends, in increasing order of preference.
# We have no particular opinion about how "backends" relate to "devices". For
# example, there could be multiple backends that provide the same kind of
# device.
_backend_factories = {}
_default_backend = None
_backends: Dict[str, Any] = {}
_backends_errors: Dict[str, str] = {}
_backend_lock = threading.Lock()
def register_backend_factory(name, factory, *, priority=0):
with _backend_lock:
if name in _backends:
raise RuntimeError(f"Backend {name} already initialized")
_backend_factories[name] = (factory, priority)
register_backend_factory(
"interpreter", xla_client.make_interpreter_client, priority=-100
)
register_backend_factory(
"cpu", partial(xla_client.make_cpu_client, use_tfrt=True), priority=0
)
def make_gpu_client(*, platform_name, visible_devices_flag):
from ..distribute import global_state
visible_devices = global_state.visible_devices
if visible_devices != "all":
allowed_devices = {int(x) for x in visible_devices.split(",")}
else:
allowed_devices = None
return xla_client.make_gpu_client(
distributed_client=global_state.client,
node_id=global_state.process_id,
platform_name=platform_name,
allowed_devices=allowed_devices,
)
if hasattr(xla_client, "make_gpu_client"):
register_backend_factory(
"cuda",
partial(
make_gpu_client,
platform_name="cuda",
visible_devices_flag="xla_cuda_visible_devices",
),
priority=200,
)
register_backend_factory(
"rocm",
partial(
make_gpu_client,
platform_name="rocm",
visible_devices_flag="xla_rocm_visible_devices",
),
priority=200,
)
if hasattr(xla_client, "make_plugin_device_client"):
# It is assumed that if xla has been built with a plugin client, then the
# user wants to use the plugin client by default. Therefore, it gets the
# highest priority.
register_backend_factory(
"plugin", xla_client.make_plugin_device_client, priority=400
)
_platform_aliases = {
"cuda": "gpu",
"rocm": "gpu",
}
_alias_to_platforms: Dict[str, List[str]] = {}
for _platform, _alias in _platform_aliases.items():
_alias_to_platforms.setdefault(_alias, []).append(_platform)
def is_known_platform(platform: str):
# A platform is valid if there is a registered factory for it. It does not
# matter if we were unable to initialize that platform; we only care that
# we've heard of it and it isn't, e.g., a typo.
return platform in _backend_factories.keys() or platform in _platform_aliases.keys()
def canonicalize_platform(platform: str) -> str:
"""Replaces platform aliases with their concrete equivalent.
In particular, replaces "gpu" with either "cuda" or "rocm", depending on which
hardware is actually present. We want to distinguish "cuda" and "rocm" for
purposes such as MLIR lowering rules, but in many cases we don't want to
force users to care.
"""
platforms = _alias_to_platforms.get(platform, None)
if platforms is None:
return platform
b = backends()
for p in platforms:
if p in b.keys():
return p
raise RuntimeError(
f"Unknown backend: '{platform}' requested, but no "
f"platforms that are instances of {platform} are present. "
"Platforms present are: " + ",".join(b.keys())
)
def expand_platform_alias(platform: str) -> List[str]:
"""Expands, e.g., "gpu" to ["cuda", "rocm"].
This is used for convenience reasons: we expect cuda and rocm to act similarly
in many respects since they share most of the same code.
"""
return _alias_to_platforms.get(platform, [platform])
def is_gpu(platform):
return platform in ("cuda", "rocm")
def backends():
global _backends
global _backends_errors
global _default_backend
with _backend_lock:
if _backends:
return _backends
if config.xla_platforms:
xla_platforms = config.xla_platforms.split(",")
platforms = []
# Allow platform aliases in the list of platforms.
for platform in xla_platforms:
platforms.extend(expand_platform_alias(platform))
priorities = range(len(platforms), 0, -1)
platforms_and_priorities = zip(platforms, priorities)
else:
platforms_and_priorities = (
(platform, priority)
for platform, (_, priority) in _backend_factories.items()
)
default_priority = -1000
if hasattr(xla_client, "maybe_load_pjrt_plugins"):
xla_client.maybe_load_pjrt_plugins()
for platform, priority in platforms_and_priorities:
try:
backend = _init_backend(platform)
_backends[platform] = backend
if priority > default_priority:
_default_backend = backend
default_priority = priority
except Exception as err:
if platform in ("cpu", "interpreter"):
# We always expect the CPU and interpreter backends to initialize
# successfully.
raise
else:
# If the backend isn't built into the binary, or if it has no devices,
# we expect a RuntimeError.
err_msg = f"Unable to initialize backend '{platform}': {err}"
if config.xla_platforms:
err_msg += " (set XLA_PLATFORMS='' to automatically choose an available backend)"
raise RuntimeError(err_msg)
else:
_backends_errors[platform] = str(err)
logger.info(err_msg)
continue
# We don't warn about falling back to CPU on Mac OS, because we don't
# support anything else there at the moment and warning would be pointless.
if (
py_platform.system() != "Darwin"
and _default_backend.platform == "cpu"
and FLAGS.xla_platform_name != "cpu"
):
logger.warning("No GPU/TPU found, falling back to CPU. ")
return _backends
def _clear_backends():
global _backends
global _backends_errors
global _default_backend
logger.info("Clearing XLA backend caches.")
with _backend_lock:
_backends = {}
_backends_errors = {}
_default_backend = None
get_backend.cache_clear()
def _init_backend(platform):
factory, unused_priority = _backend_factories.get(platform, (None, None))
if factory is None:
raise RuntimeError(f"Unknown backend '{platform}'")
logger.debug("Initializing backend '%s'", platform)
backend = factory()
# TODO: consider raising more descriptive errors directly from backend
# factories instead of returning None.
if backend is None:
raise RuntimeError(f"Could not initialize backend '{platform}'")
if backend.device_count() == 0:
raise RuntimeError(f"Backend '{platform}' provides no devices.")
logger.debug("Backend '%s' initialized", platform)
return backend
def _get_backend_uncached(platform=None):
# TODO: remove this input polymorphism after we clean up how
# 'backend' values are handled
if not isinstance(platform, (type(None), str)):
return platform
platform = platform or FLAGS.xla_backend or FLAGS.xla_platform_name or None
bs = backends()
if platform is not None:
platform = canonicalize_platform(platform)
backend = bs.get(platform, None)
if backend is None:
if platform in _backends_errors:
raise RuntimeError(
f"Backend '{platform}' failed to initialize: "
f"{_backends_errors[platform]}"
)
raise RuntimeError(f"Unknown backend {platform}")
return backend
else:
return _default_backend
@lru_cache(maxsize=None) # don't use util.memoize because there is no X64 dependence.
def get_backend(platform=None):
return _get_backend_uncached(platform)
def get_device_backend(device=None):
"""Returns the Backend associated with `device`, or the default Backend."""
if device is not None:
return device.client
return get_backend()
def device_count(backend: Optional[Union[str, XlaBackend]] = None) -> int:
"""Returns the total number of devices.
On most platforms, this is the same as :py:func:`xla.local_device_count`.
However, on multi-process platforms where different devices are associated
with different processes, this will return the total number of devices across
all processes.
Args:
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
``'tpu'``.
Returns:
Number of devices.
"""
return int(get_backend(backend).device_count())
def local_device_count(backend: Optional[Union[str, XlaBackend]] = None) -> int:
"""Returns the number of devices addressable by this process."""
return int(get_backend(backend).local_device_count())
def devices(
backend: Optional[Union[str, XlaBackend]] = None
) -> List[xla_client.Device]:
"""Returns a list of all devices for a given backend.
Each device is represented by a subclass of :class:`Device` (e.g.
:class:`CpuDevice`, :class:`GpuDevice`). The length of the returned list is
equal to ``device_count(backend)``. Local devices can be identified by
comparing :attr:`Device.process_index` to the value returned by
:py:func:`xla.process_index`.
If ``backend`` is ``None``, returns all the devices from the default backend.
The default backend is generally ``'gpu'`` or ``'tpu'`` if available,
otherwise ``'cpu'``.
Args:
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
``'tpu'``.
Returns:
List of Device subclasses.
"""
return get_backend(backend).devices()
def default_backend() -> str:
"""Returns the platform name of the default XLA backend."""
return get_backend(None).platform
def local_devices(
process_index: Optional[int] = None,
backend: Optional[Union[str, XlaBackend]] = None,
host_id: Optional[int] = None,
) -> List[xla_client.Device]:
"""Like :py:func:`xla.devices`, but only returns devices local to a given process.
If ``process_index`` is ``None``, returns devices local to this process.
Args:
process_index: the integer index of the process. Process indices can be
retrieved via ``range(xla.process_count())``.
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
``'tpu'``.
Returns:
List of Device subclasses.
"""
if host_id is not None:
warnings.warn(
"The argument to xla.local_devices has been renamed from `host_id` to "
"`process_index`. This alias will eventually be removed; please update "
"your code."
)
process_index = host_id
if process_index is None:
process_index = get_backend(backend).process_index()
if not (0 <= process_index < process_count()):
raise ValueError(f"Unknown process_index {process_index}")
return [d for d in devices(backend) if d.process_index == process_index]
def process_index(backend: Optional[Union[str, XlaBackend]] = None) -> int:
"""Returns the integer process index of this process.
On most platforms, this will always be 0. This will vary on multi-process
platforms though.
Args:
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
``'tpu'``.
Returns:
Integer process index.
"""
return get_backend(backend).process_index()
def host_id(backend=None):
warnings.warn(
"xla.host_id has been renamed to xla.process_index. This alias "
"will eventually be removed; please update your code."
)
return process_index(backend)
def process_count(backend: Optional[Union[str, XlaBackend]] = None) -> int:
"""Returns the number of XLA processes associated with the backend."""
return max(d.process_index for d in devices(backend)) + 1
def host_count(backend=None):
warnings.warn(
"xla.host_count has been renamed to xla.process_count. This alias "
"will eventually be removed; please update your code."
)
return process_count(backend)
def host_ids(backend=None):
warnings.warn(
"xla.host_ids has been deprecated; please use range(xla.process_count()) "
"instead. xla.host_ids will eventually be removed; please update your "
"code."
)
return list(range(process_count(backend)))
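# Illustrative sketch (not part of the original source): combining the query
# helpers above. On a single-process setup the two device counts coincide and
# process_index() is 0; the device filter reproduces what local_devices() does.
def _example_topology_summary():
    me = process_index()
    return {
        "global_devices": device_count(),
        "local_devices": local_device_count(),
        "processes": process_count(),
        "my_devices": [d for d in devices() if d.process_index == me],
    }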
import dataclasses
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union
import numpy as np
from ..core._imperative_rt.core2 import OpInfo, VarInfo
from . import utils
from .device import xb
from .ir_utils import TraceResult, ir_constant_tuple, mge_varinfo_to_ir_type_tuple
from .lib import xla_client as xc
from .lib.mlir import dialects, ir
from .lib.mlir.dialects import func as func_dialect
from .rules import get_rule
from .rules.hlotensor import HLOTensor
from .rules.utils import _shape_equal
from .sharding import sharded_val
def make_ir_context() -> ir.Context:
context = ir.Context()
dialects.mhlo.register_mhlo_dialect(context)
dialects.chlo.register_dialect(context)
dialects.stablehlo.register_dialect(context)
return context
@dataclasses.dataclass
class ModuleContext:
context: ir.Context
module: ir.Module
ip: ir.InsertionPoint
symbol_table: ir.SymbolTable
backend_or_name: Optional[Union[str, xb.XlaBackend]]
platform: str
keepalives: List[Any]
channel_iterator: Iterator[int]
host_callbacks: List[Any]
# Stores the values of VarInfos that can be inferred during the lowering process
inferred_values: Dict[VarInfo, np.ndarray]
def __init__(
self,
backend_or_name: Optional[Union[str, xb.XlaBackend]],
platform: str,
keepalives: List[Any] = [],
host_callbacks: List[Any] = [],
context: Optional[ir.Context] = None,
module: Optional[ir.Module] = None,
ip: Optional[ir.InsertionPoint] = None,
symbol_table: Optional[ir.SymbolTable] = None,
):
assert platform is not None
self.context = context or make_ir_context()
self.module = module or ir.Module.create(loc=ir.Location.unknown(self.context))
self.ip = ip or ir.InsertionPoint(self.module.body)
self.symbol_table = symbol_table or ir.SymbolTable(self.module.operation)
self.backend_or_name = backend_or_name
self.platform = platform
self.keepalives = keepalives
self.host_callbacks = host_callbacks
self.inferred_values = {}
@property
def backend(self) -> xb.XlaBackend:
if self.backend_or_name is None or isinstance(self.backend_or_name, str):
return xb.get_backend(self.backend_or_name)
return self.backend_or_name
def replace(self, **kw):
return dataclasses.replace(self, **kw)
def get_value(self, varinfo):
assert varinfo in self.inferred_values
return self.inferred_values[varinfo]
def set_value(self, varinfo, value):
self.inferred_values[varinfo] = value
@dataclasses.dataclass
class LoweringRuleContext:
module_context: ModuleContext
op: OpInfo
vars_in: Sequence[VarInfo]
vars_out: Sequence[VarInfo]
param: Optional[Dict] = None
def replace(self, **kw):
return dataclasses.replace(self, **kw)
def _unwrap_singleton_ir_values(x):
return x[0] if len(x) == 1 else x
def _wrap_singleton_ir_values(
x: Union[ir.Value, Sequence[ir.Value]]
) -> Sequence[ir.Value]:
return (x,) if isinstance(x, ir.Value) else tuple(x)
def lowering_ops(
ctx: ModuleContext, trace_result: TraceResult, *args: Sequence[ir.Value],
):
# var_id -> ir.Value
env: Dict[int, Tuple[ir.Value, ...]] = {}
consts = list(map(ir_constant_tuple, trace_result._var_consts))
# read ir.Values from env according to var_ids
def read(var_ids):
assert isinstance(var_ids, (list, tuple))
ret = []
for vid in var_ids:
assert isinstance(vid, int)
ret.append(env[vid])
return ret
# update env with var_ids and ir.Values
def write(var_ids, hlo_nodes):
assert isinstance(var_ids, (list, tuple))
assert isinstance(hlo_nodes, (map, list, tuple))
hlo_nodes = list(hlo_nodes)
assert len(var_ids) == len(hlo_nodes), (len(var_ids), len(hlo_nodes))
for vid, node in zip(var_ids, hlo_nodes):
assert vid not in env
env[vid] = node
assert len(args) == len(trace_result.inputs)
assert len(consts) == len(trace_result.consts)
assert all(isinstance(v, ir.Value) for vs in consts for v in vs)
# initialize env with inputs and consts
write(trace_result.inputs, args)
write(trace_result.consts, consts)
for eqn in trace_result.eqns:
rule_ctx = LoweringRuleContext(
module_context=ctx,
op=eqn.op,
vars_in=[trace_result.vars[inp] for inp in eqn.inputs],
vars_out=[trace_result.vars[oup] for oup in eqn.outputs],
param=eqn.param,
)
rule = get_rule(eqn.op, use_fake_rule_for_debug=False)
in_nodes = read(eqn.inputs)
hinps = [
HLOTensor(irval, var.shape, var.dtype)
for var, irval in zip(
rule_ctx.vars_in, map(_unwrap_singleton_ir_values, in_nodes)
)
]
houps = rule(rule_ctx, *hinps)
if isinstance(houps, HLOTensor):
houps = [houps]
out_nodes = []
for out_id, hlo_out in zip(eqn.outputs, houps):
var_out = trace_result.vars[out_id]
assert _shape_equal(
var_out.shape, hlo_out.shape
), f"{eqn.op}: {var_out.shape} != {hlo_out.shape}"
out_nodes.append(hlo_out.tensor)
out_nodes = tuple(map(_wrap_singleton_ir_values, out_nodes))
write(eqn.outputs, out_nodes)
return read(trace_result.outputs)
def make_xla_graph(
ctx: ModuleContext,
name: str,
trace_result: TraceResult,
public: bool = True,
in_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
out_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
input_output_aliases: Optional[Sequence[Optional[int]]] = None,
) -> func_dialect.FuncOp:
assert public is True, "only public function visibility is supported"
assert (
in_shardings is None and out_shardings is None
), "sharding when lowering is not supported yet"
assert (
input_output_aliases is None or input_output_aliases == []
), "donated inputs are not supported yet"
input_types = [
mge_varinfo_to_ir_type_tuple(trace_result.vars[idx])
for idx in trace_result.inputs
]
output_types = [
mge_varinfo_to_ir_type_tuple(trace_result.vars[idx])
for idx in trace_result.outputs
]
flat_input_types = utils.flatten_list(input_types)
flat_output_types = utils.flatten_list(output_types)
assert len(flat_input_types) == len(trace_result.inputs)
assert len(flat_output_types) == len(trace_result.outputs)
ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
func_op = func_dialect.FuncOp(name, ftype, ip=ctx.ip)
func_op.attributes["sym_visibility"] = ir.StringAttr.get(
"public" if public else "private"
)
ctx.symbol_table.insert(func_op)
entry_block = func_op.add_entry_block()
with ir.InsertionPoint(entry_block):
flat_args = entry_block.arguments
unflattened_args = utils.unflatten_list(flat_args, map(len, input_types))
outs = lowering_ops(ctx, trace_result, *unflattened_args)
flat_oups = utils.flatten_list(outs)
func_dialect.ReturnOp(flat_oups)
return func_op
def lower(
trace_result: TraceResult,
backend,
platform,
in_shardings=None,
out_shardings=None,
donated_invars=None,
):
assert donated_invars is None, "donated inputs are not supported yet"
assert trace_result.effects == [], "effect of trace is not supported"
if in_shardings is not None:
trace_result.inputs = [
sharded_val(inp, in_sharding)
for inp, in_sharding in zip(trace_result.inputs, in_shardings)
]
if out_shardings is not None:
trace_result.outputs = [
sharded_val(outp, out_sharding)
for outp, out_sharding in zip(trace_result.outputs, out_shardings)
]
ctx = ModuleContext(backend, platform)
with ctx.context, ir.Location.unknown(ctx.context):
module_name = trace_result.func_name
ctx.module.operation.attributes["sym_name"] = ir.StringAttr.get(module_name)
assert trace_result.effects == [], "effect of trace is not supported"
make_xla_graph(
ctx,
"main",
trace_result,
public=True,
in_shardings=None,
out_shardings=None,
input_output_aliases=[],
)
return ctx.module, ctx.keepalives, ctx.host_callbacks
from . import (
communicate,
elemwise,
indexing,
math,
nn,
normalize,
random,
reduction,
tensor,
trivial,
)
from .utils import get_rule
from typing import Sequence, Union
import numpy as np
from ...core._imperative_rt import ops as mops
from ..lib.mlir import ir
from .hlotensor import HLOTensor
from .tensor import fill
from .utils import _check_shape, register_lower_rule
@register_lower_rule(mops.GetVarShape)
def get_var_shape_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
if len(args) > 1:
assert len(args) == 2, f"{len(args)}"
_check_shape(args[0].shape, args[1].shape)
shp = args[0].shape
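# axis == 7 acts as the "no axis" sentinel here (presumably MegEngine's max ndim
# / INVALID_AXIS): return the full shape instead of a single dimension.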
if ctx.op.axis != 7:
shp = (shp[ctx.op.axis],)
shp = np.array(shp, np.int64)
ctx.module_context.set_value(ctx.vars_out[0], shp)
return HLOTensor(shp)
@register_lower_rule("create_tensor")
def create_tensor_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
assert len(args) == len(ctx.vars_in) == len(ctx.vars_out) == 1
var_in, var_out = ctx.vars_in[0], ctx.vars_out[0]
if var_in.bound_data is not None:
ctx.module_context.set_value(var_in, var_in.bound_data)
ctx.module_context.set_value(var_out, var_in.bound_data)
assert var_in.shape == var_out.shape
if var_out.bound_data is not None:
data = np.asarray(var_out.bound_data, var_out.dtype)
elif var_in.bound_data is not None:
data = np.asarray(var_in.bound_data, var_out.dtype)
else:
assert False, "only support create tensor from const now"
return HLOTensor(data)
@register_lower_rule("io_mark_var")
def io_mark_var_lower(ctx, *args: Union[ir.Value, Sequence[ir.Value]]):
assert len(args) == 1
return args
@register_lower_rule("rename")
def rename_lower(ctx, *args: Union[ir.Value, Sequence[ir.Value]]):
assert len(args) == 1
return args
@register_lower_rule("fake_op_rule_for_debug")
def fake_op_lower(ctx, *args: Union[ir.Value, Sequence[ir.Value]]):
return [fill(0.0, out.shape, out.dtype) for out in ctx.vars_out]
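# Illustrative sketch (not part of the original source): registering a lowering
# rule for a hypothetical op name follows the same pattern as the rules above,
# i.e. take the lowering context plus input HLOTensors and return HLOTensor(s).
@register_lower_rule("example_identity")
def _example_identity_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
    assert len(args) == len(ctx.vars_in) == len(ctx.vars_out) == 1
    return args[0]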