提交 582dd4ce 编写于 作者: M Megvii Engine Team

fix(dnn/sfotmax): call cpu dispatch for softmax opr

GitOrigin-RevId: a606e66101614a4bf1135d047a163bd54ad7a650
上级 2976f60b
......@@ -6,35 +6,19 @@
namespace megdnn {
namespace fallback {
void SoftmaxForwardImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
auto axis = param().axis;
if (axis < 0)
axis += src.layout.ndim;
megdnn_assert(axis >= 0);
check_exec(src.layout, dst.layout, workspace.size);
if (!usable(src.layout)) {
naive::SoftmaxForwardImpl::exec(src, dst, workspace);
return;
}
typedef DTypeTrait<dtype::Float32>::ctype Float32;
auto sptr = src.ptr<Float32>();
auto dptr = dst.ptr<Float32>();
constexpr auto float_min = std::numeric_limits<Float32>::min();
constexpr auto step = GI_SIMD_LEN_BYTE / sizeof(Float32);
size_t A, B, C;
reduce::get_ABC(src.layout, A, B, C, axis);
static void do_softmax(
const float* sptr, float* dptr, size_t A, size_t B, size_t C,
_megdnn_workspace workspace) {
constexpr auto float_min = std::numeric_limits<float>::min();
constexpr auto step = GI_SIMD_LEN_BYTE / sizeof(float);
// TODO: When C=2,3,4..., src_ptr span is relatively large, the performance may
// be poor
if (C != 1) {
WorkspaceBundle workspace_bundle{
workspace.raw_ptr, {A * C * sizeof(Float32), A * C * sizeof(Float32)}};
Float32* max = workspace_bundle.get_workspace(0).raw_ptr->as<Float32>();
workspace.raw_ptr, {A * C * sizeof(float), A * C * sizeof(float)}};
float* max = workspace_bundle.get_workspace(0).raw_ptr->as<float>();
GI_FLOAT32_t v_max = GiBroadcastFloat32(float_min);
size_t i = 0;
for (; i + step <= A * C; i += step)
......@@ -60,8 +44,8 @@ void SoftmaxForwardImpl::exec(
}
}
Float32* sum = workspace_bundle.get_workspace(1).raw_ptr->as<Float32>();
memset(sum, 0, A * C * sizeof(Float32));
float* sum = workspace_bundle.get_workspace(1).raw_ptr->as<float>();
memset(sum, 0, A * C * sizeof(float));
for (size_t a = 0; a < A; a++) {
for (size_t b = 0; b < B; b++) {
auto max_ptr = max + a * C;
......@@ -157,6 +141,28 @@ void SoftmaxForwardImpl::exec(
}
}
void SoftmaxForwardImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
auto axis = param().axis;
if (axis < 0)
axis += src.layout.ndim;
megdnn_assert(axis >= 0);
check_exec(src.layout, dst.layout, workspace.size);
if (!usable(src.layout)) {
naive::SoftmaxForwardImpl::exec(src, dst, workspace);
return;
}
typedef DTypeTrait<dtype::Float32>::ctype Float32;
auto sptr = src.ptr<Float32>();
auto dptr = dst.ptr<Float32>();
size_t A, B, C;
reduce::get_ABC(src.layout, A, B, C, axis);
MEGDNN_DISPATCH_CPU_KERN_OPR(do_softmax(sptr, dptr, A, B, C, workspace));
}
} // namespace fallback
} // namespace megdnn
......
......@@ -1653,3 +1653,18 @@ def test_conv_transpose3d():
np.testing.assert_equal(
output_shape.numpy(), np.array([20, 33, 32, 96, 197], dtype=np.int32)
)
@pytest.mark.skip(reason="pytest aborted")
def test_softmax():
def np_softmax(x):
return np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True)
data = (np.random.random(size=(1, 16, 224, 224)).astype(np.float32) - 0.5) * 100
desired = np_softmax(data[:, :3, 0, 0])
data = Tensor(data)
data = data[:, :3, 0, 0]
actual = F.softmax(data)
np.testing.assert_allclose(actual.numpy(), desired, rtol=1e-5)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册