From 89303cd829b545e38a9303c5e6a91983ab5821eb Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 4 Nov 2020 18:07:41 +0800 Subject: [PATCH] feat(megdnn/rocm): add bn for rocm backend GitOrigin-RevId: 8bd49599b28847e51735ff57293bc16ff0033924 --- cmake/rocm.cmake | 21 ++-- dnn/src/rocm/batch_normalization/opr_impl.cpp | 116 ++++++++++++++++++ dnn/src/rocm/batch_normalization/opr_impl.h | 80 ++++++++++++ dnn/src/rocm/handle.cpp | 3 + dnn/test/rocm/bn.cpp | 71 +++++++++++ 5 files changed, 281 insertions(+), 10 deletions(-) create mode 100644 dnn/src/rocm/batch_normalization/opr_impl.cpp create mode 100644 dnn/src/rocm/batch_normalization/opr_impl.h create mode 100644 dnn/test/rocm/bn.cpp diff --git a/cmake/rocm.cmake b/cmake/rocm.cmake index db5c0e748..e5f49b116 100644 --- a/cmake/rocm.cmake +++ b/cmake/rocm.cmake @@ -17,10 +17,11 @@ string(REPLACE "." ";" HIP_VERSION_LIST ${HIP_VERSION}) list(GET HIP_VERSION_LIST 0 HIP_VERSION_MAJOR) list(GET HIP_VERSION_LIST 1 HIP_VERSION_MINOR) if (NOT ${HIP_VERSION_MAJOR} STREQUAL "3") - message(FATAL_ERROR "ROCM version needed 3.7.Please update ROCM.") -endif() -if (NOT ${HIP_VERSION_MINOR} STREQUAL "7") - message(FATAL_ERROR "ROCM version needed 3.7.Please update ROCM.") + message(FATAL_ERROR "ROCM version needed 3.x, Please update ROCM.") +else() + if (${HIP_VERSION_MINOR} LESS "7") + message(WARNING "ROCM version 3.x which x(got ${HIP_VERSION_MINOR}) greater equal 7 is prefered.") + endif() endif() set(MGE_ROCM_LIBS OpenCL amdhip64 MIOpen rocblas rocrand) @@ -37,7 +38,7 @@ find_path(MIOPEN_LIBRARY_DIR DOC "Path to MIOPEN library directory." ) if(MIOPEN_LIBRARY_DIR STREQUAL "MIOPEN_LIBRARY_DIR-NOTFOUND") - message(FATAL_ERROR "Can not find MIOPEN Library") + message(FATAL_ERROR "Can not find MIOPEN Library") endif() get_filename_component(__found_miopen_include ${HIP_ROOT_DIR}/../miopen/include REALPATH) @@ -48,7 +49,7 @@ find_path(MIOPEN_INCLUDE_DIR DOC "Path to MIOPEN include directory." ) if(MIOPEN_INCLUDE_DIR STREQUAL "MIOPEN_INCLUDE_DIR-NOTFOUND") - message(FATAL_ERROR "Can not find MIOEPN INCLUDE") + message(FATAL_ERROR "Can not find MIOEPN INCLUDE") endif() #rocblas @@ -60,7 +61,7 @@ find_path(ROCBLAS_LIBRARY_DIR DOC "Path to ROCBLAS library directory." ) if(ROCBLAS_LIBRARY_DIR STREQUAL "ROCBLAS_LIBRARY_DIR-NOTFOUND") - message(FATAL_ERROR "Can not find ROCBLAS Library") + message(FATAL_ERROR "Can not find ROCBLAS Library") endif() get_filename_component(__found_rocblas_include ${HIP_ROOT_DIR}/../rocblas/include REALPATH) @@ -71,7 +72,7 @@ find_path(ROCBLAS_INCLUDE_DIR DOC "Path to ROCBLAS include directory." ) if(ROCBLAS_INCLUDE_DIR STREQUAL "ROCBLAS_INCLUDE_DIR-NOTFOUND") - message(FATAL_ERROR "Can not find ROCBLAS INCLUDE") + message(FATAL_ERROR "Can not find ROCBLAS INCLUDE") endif() #rocrand @@ -83,7 +84,7 @@ find_path(ROCRAND_LIBRARY_DIR DOC "Path to ROCRAND library directory." ) if(ROCRAND_LIBRARY_DIR STREQUAL "ROCRAND_LIBRARY_DIR-NOTFOUND") - message(FATAL_ERROR "Can not find ROCRAND Library") + message(FATAL_ERROR "Can not find ROCRAND Library") endif() get_filename_component(__found_rocrand_include ${HIP_ROOT_DIR}/../rocrand/include REALPATH) @@ -94,7 +95,7 @@ find_path(ROCRAND_INCLUDE_DIR DOC "Path to ROCRAND include directory." ) if(ROCRAND_INCLUDE_DIR STREQUAL "ROCRAND_INCLUDE_DIR-NOTFOUND") - message(FATAL_ERROR "Can not find ROCRAND INCLUDE") + message(FATAL_ERROR "Can not find ROCRAND INCLUDE") endif() diff --git a/dnn/src/rocm/batch_normalization/opr_impl.cpp b/dnn/src/rocm/batch_normalization/opr_impl.cpp new file mode 100644 index 000000000..090854d18 --- /dev/null +++ b/dnn/src/rocm/batch_normalization/opr_impl.cpp @@ -0,0 +1,116 @@ +/** + * \file dnn/src/rocm/batch_normalization/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "./opr_impl.h" + +#include "src/rocm/utils.h" + +namespace megdnn { +namespace rocm { + +namespace batch_normalization { + +void BNTensorDescHolder::setup(const TensorLayout& x, + const ParamDim& param_dim) { + TensorShape xy_shape(x); + + switch (param_dim) { + case ParamDim::DIM_11HW: + // xy: N, C, H, W --> (N*C), 1, H, W + xy_shape.shape[0] = xy_shape.shape[0] * xy_shape.shape[1]; + xy_shape.shape[1] = 1; + bn_mode = miopenBNPerActivation; + break; + case ParamDim::DIM_1CHW: + bn_mode = miopenBNPerActivation; + break; + case ParamDim::DIM_1C11: + bn_mode = miopenBNSpatial; + break; + default: + megdnn_throw(megdnn_mangle( + "Unknown param dim type of batch normalization.")); + } + xy_desc.set(TensorLayout(xy_shape, x.dtype)); + param_desc.set(xy_desc.desc, bn_mode); +} + +} // namespace batch_normalization + +void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, + _megdnn_tensor_in bn_bias, _megdnn_tensor_out mean, + _megdnn_tensor_out variance, + _megdnn_tensor_out batch_mean, + _megdnn_tensor_out batch_inv_variance, + _megdnn_tensor_out dst, _megdnn_workspace workspace) { + check_exec(src.layout, bn_scale.layout, bn_bias.layout, mean.layout, + variance.layout, batch_mean.layout, batch_inv_variance.layout, + dst.layout, workspace.size); + auto handle = concrete_handle(this->handle())->miopen_handle(); + m_tensor_desc.setup(src.layout, m_param.param_dim); + + float alpha = 1.0f, beta = 0.0f; + switch (m_param.fwd_mode) { + case param::BN::FwdMode::TRAINING: + miopen_check(miopenBatchNormalizationForwardTraining( + handle, m_tensor_desc.bn_mode, &alpha, &beta, + m_tensor_desc.xy_desc.desc, // xDesc + src.raw_ptr, // x + m_tensor_desc.xy_desc.desc, // yDesc + dst.raw_ptr, // y + m_tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc + bn_scale.raw_ptr, bn_bias.raw_ptr, m_param.avg_factor, + mean.raw_ptr, variance.raw_ptr, m_param.epsilon, + batch_mean.raw_ptr, batch_inv_variance.raw_ptr)); + + break; + case param::BN::FwdMode::INFERENCE: + miopen_check(miopenBatchNormalizationForwardInference( + handle, m_tensor_desc.bn_mode, &alpha, &beta, + m_tensor_desc.xy_desc.desc, src.raw_ptr, + m_tensor_desc.xy_desc.desc, dst.raw_ptr, + m_tensor_desc.param_desc.desc, bn_scale.raw_ptr, + bn_bias.raw_ptr, mean.raw_ptr, variance.raw_ptr, + m_param.epsilon)); + break; + default: + megdnn_throw(megdnn_mangle( + "Unknown forward mode type of batch normalization.")); + } +} + +void BNBackwardImpl::exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, + _megdnn_tensor_in saved_batch_mean, + _megdnn_tensor_in saved_batch_inv_variance, + _megdnn_tensor_in bn_scale, + _megdnn_tensor_out d_bn_scale, + _megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, + _megdnn_workspace workspace) { + check_exec(x.layout, dy.layout, saved_batch_mean.layout, + saved_batch_inv_variance.layout, bn_scale.layout, + d_bn_scale.layout, d_bn_bias.layout, dx.layout, workspace.size); + auto handle = concrete_handle(this->handle())->miopen_handle(); + m_tensor_desc.setup(x.layout, m_param.param_dim); + + float alpha = 1.0, beta = 0.0; + miopen_check(miopenBatchNormalizationBackward( + handle, m_tensor_desc.bn_mode, &alpha, &beta, &alpha, &beta, + m_tensor_desc.xy_desc.desc, x.raw_ptr, m_tensor_desc.xy_desc.desc, + dy.raw_ptr, m_tensor_desc.xy_desc.desc, dx.raw_ptr, + m_tensor_desc.param_desc.desc, bn_scale.raw_ptr, d_bn_scale.raw_ptr, + d_bn_bias.raw_ptr, m_param.epsilon, saved_batch_mean.raw_ptr, + saved_batch_inv_variance.raw_ptr)); +} + +} // namespace rocm +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/rocm/batch_normalization/opr_impl.h b/dnn/src/rocm/batch_normalization/opr_impl.h new file mode 100644 index 000000000..1bae0bd5b --- /dev/null +++ b/dnn/src/rocm/batch_normalization/opr_impl.h @@ -0,0 +1,80 @@ +/** + * \file dnn/src/rocm/batch_normalization/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include "megdnn/oprs.h" + +#include "src/rocm/miopen_wrapper.h" + +namespace megdnn { +namespace rocm { + +namespace batch_normalization { + +struct BNTensorDescHolder { + using ParamDim = param::BN::ParamDim; + + TensorDesc xy_desc; + BNParamDesc param_desc; + miopenBatchNormMode_t bn_mode; + + void setup(const TensorLayout& x, const ParamDim& param_dim); +}; + +} // namespace batch_normalization + +class BNForwardImpl final : public BNForward { +public: + using BNForward::BNForward; + void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, + _megdnn_tensor_in bn_bias, _megdnn_tensor_out mean, + _megdnn_tensor_out variance, _megdnn_tensor_out batch_mean, + _megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, + const TensorLayout&, const TensorLayout&, + const TensorLayout&, const TensorLayout&, + const TensorLayout&, + const TensorLayout&) override { + return 0; + } + +private: + batch_normalization::BNTensorDescHolder m_tensor_desc; +}; + +class BNBackwardImpl final : public BNBackward { +public: + using BNBackward::BNBackward; + void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, + _megdnn_tensor_in saved_batch_mean, + _megdnn_tensor_in saved_batch_inv_variance, + _megdnn_tensor_in bn_scale, _megdnn_tensor_out d_bn_scale, + _megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, + _megdnn_workspace workspace) override; + + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, + const TensorLayout&, const TensorLayout&, + const TensorLayout&, const TensorLayout&, + const TensorLayout&, + const TensorLayout&) override { + return 0; + } + +private: + batch_normalization::BNTensorDescHolder m_tensor_desc; +}; + +} // namespace rocm +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/rocm/handle.cpp b/dnn/src/rocm/handle.cpp index 62da48c28..802d75c7f 100644 --- a/dnn/src/rocm/handle.cpp +++ b/dnn/src/rocm/handle.cpp @@ -35,6 +35,7 @@ #include "src/rocm/linspace/opr_impl.h" #include "src/rocm/argmxx/opr_impl.h" #include "src/rocm/sleep/opr_impl.h" +#include "src/rocm/batch_normalization/opr_impl.h" #include #include @@ -171,6 +172,8 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(Linspace); MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgmaxForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgminForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(SleepForward); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNForward); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNBackward); #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wpragmas" diff --git a/dnn/test/rocm/bn.cpp b/dnn/test/rocm/bn.cpp new file mode 100644 index 000000000..b84b2ae0b --- /dev/null +++ b/dnn/test/rocm/bn.cpp @@ -0,0 +1,71 @@ +/** + * \file dnn/test/rocm/bn.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "test/rocm/fixture.h" + +#include "megdnn/opr_param_defs.h" +#include "megdnn/oprs.h" +#include "test/common/bn.h" +#include "test/common/checker.h" +#include "test/common/rng.h" +#include "test/common/tensor.h" +#include "test/common/workspace_wrapper.h" + +namespace megdnn { +namespace test { + +TEST_F(ROCM, BN_FORWARD) { + using namespace batch_normalization; + std::vector args = get_args(); + Checker checker(handle_rocm()); + for (auto&& arg : args) { + for (int i = 0; i < 8; ++i) { + checker.set_dtype(i, dtype::Float32()); + } + checker.set_dtype(0, arg.dtype); + checker.set_epsilon(1e-3).set_param(arg.param); + for (bool need_statistic : {false, true}) + checker.exec({ + arg.src, + arg.param_shape, // bn_scale + arg.param_shape, // bn_bias + need_statistic ? arg.param_shape + : TensorShape({0}), // mean + need_statistic ? arg.param_shape + : TensorShape({0}), // variance + arg.param_shape, // batch_mean + arg.param_shape, // batch_inv_variance + {} // dst + }); + } +} + +TEST_F(ROCM, BN_BACKWARD) { + using namespace batch_normalization; + std::vector args = get_args(); + Checker checker(handle_rocm()); + for (auto&& arg : args) { + for (int i = 0; i < 8; ++i) { + checker.set_dtype(i, dtype::Float32()); + } + checker.set_dtype(0, arg.dtype) // x + .set_dtype(1, arg.dtype) // dy + .set_dtype(7, arg.dtype); // dx + checker.set_epsilon(1e-3).set_param(arg.param).exec( + {arg.src, arg.src, arg.param_shape, arg.param_shape, + arg.param_shape, arg.param_shape, arg.param_shape, arg.src}); + } +} + +} // namespace test +} // namespace megdnn + +// vim: syntax=cpp.doxygen -- GitLab