Commit 89303cd8 authored by Megvii Engine Team

feat(megdnn/rocm): add bn for rocm backend

GitOrigin-RevId: 8bd49599b28847e51735ff57293bc16ff0033924
Parent 7cd846a5
......@@ -17,10 +17,11 @@ string(REPLACE "." ";" HIP_VERSION_LIST ${HIP_VERSION})
list(GET HIP_VERSION_LIST 0 HIP_VERSION_MAJOR)
list(GET HIP_VERSION_LIST 1 HIP_VERSION_MINOR)
if (NOT ${HIP_VERSION_MAJOR} STREQUAL "3")
message(FATAL_ERROR "ROCM version needed 3.7.Please update ROCM.")
endif()
if (NOT ${HIP_VERSION_MINOR} STREQUAL "7")
message(FATAL_ERROR "ROCM version needed 3.7.Please update ROCM.")
message(FATAL_ERROR "ROCM version needed 3.x, Please update ROCM.")
else()
if (${HIP_VERSION_MINOR} LESS "7")
message(WARNING "ROCM version 3.x which x(got ${HIP_VERSION_MINOR}) greater equal 7 is prefered.")
endif()
endif()
set(MGE_ROCM_LIBS OpenCL amdhip64 MIOpen rocblas rocrand)
......@@ -37,7 +38,7 @@ find_path(MIOPEN_LIBRARY_DIR
DOC "Path to MIOPEN library directory." )
if(MIOPEN_LIBRARY_DIR STREQUAL "MIOPEN_LIBRARY_DIR-NOTFOUND")
message(FATAL_ERROR "Can not find MIOPEN Library")
message(FATAL_ERROR "Can not find MIOPEN Library")
endif()
get_filename_component(__found_miopen_include ${HIP_ROOT_DIR}/../miopen/include REALPATH)
......@@ -48,7 +49,7 @@ find_path(MIOPEN_INCLUDE_DIR
DOC "Path to MIOPEN include directory." )
if(MIOPEN_INCLUDE_DIR STREQUAL "MIOPEN_INCLUDE_DIR-NOTFOUND")
message(FATAL_ERROR "Can not find MIOEPN INCLUDE")
message(FATAL_ERROR "Can not find MIOEPN INCLUDE")
endif()
#rocblas
......@@ -60,7 +61,7 @@ find_path(ROCBLAS_LIBRARY_DIR
DOC "Path to ROCBLAS library directory." )
if(ROCBLAS_LIBRARY_DIR STREQUAL "ROCBLAS_LIBRARY_DIR-NOTFOUND")
message(FATAL_ERROR "Can not find ROCBLAS Library")
message(FATAL_ERROR "Can not find ROCBLAS Library")
endif()
get_filename_component(__found_rocblas_include ${HIP_ROOT_DIR}/../rocblas/include REALPATH)
......@@ -71,7 +72,7 @@ find_path(ROCBLAS_INCLUDE_DIR
DOC "Path to ROCBLAS include directory." )
if(ROCBLAS_INCLUDE_DIR STREQUAL "ROCBLAS_INCLUDE_DIR-NOTFOUND")
message(FATAL_ERROR "Can not find ROCBLAS INCLUDE")
message(FATAL_ERROR "Can not find ROCBLAS INCLUDE")
endif()
#rocrand
......@@ -83,7 +84,7 @@ find_path(ROCRAND_LIBRARY_DIR
DOC "Path to ROCRAND library directory." )
if(ROCRAND_LIBRARY_DIR STREQUAL "ROCRAND_LIBRARY_DIR-NOTFOUND")
message(FATAL_ERROR "Can not find ROCRAND Library")
message(FATAL_ERROR "Can not find ROCRAND Library")
endif()
get_filename_component(__found_rocrand_include ${HIP_ROOT_DIR}/../rocrand/include REALPATH)
......@@ -94,7 +95,7 @@ find_path(ROCRAND_INCLUDE_DIR
DOC "Path to ROCRAND include directory." )
if(ROCRAND_INCLUDE_DIR STREQUAL "ROCRAND_INCLUDE_DIR-NOTFOUND")
message(FATAL_ERROR "Can not find ROCRAND INCLUDE")
message(FATAL_ERROR "Can not find ROCRAND INCLUDE")
endif()
/**
* \file dnn/src/rocm/batch_normalization/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "./opr_impl.h"
#include "src/rocm/utils.h"
namespace megdnn {
namespace rocm {
namespace batch_normalization {
void BNTensorDescHolder::setup(const TensorLayout& x,
const ParamDim& param_dim) {
TensorShape xy_shape(x);
switch (param_dim) {
case ParamDim::DIM_11HW:
// xy: N, C, H, W --> (N*C), 1, H, W
xy_shape.shape[0] = xy_shape.shape[0] * xy_shape.shape[1];
xy_shape.shape[1] = 1;
bn_mode = miopenBNPerActivation;
break;
case ParamDim::DIM_1CHW:
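// param shape is 1, C, H, W: use per-activation mode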
bn_mode = miopenBNPerActivation;
break;
case ParamDim::DIM_1C11:
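// param shape is 1, C, 1, 1: use spatial mode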
bn_mode = miopenBNSpatial;
break;
default:
megdnn_throw(megdnn_mangle(
"Unknown param dim type of batch normalization."));
}
xy_desc.set(TensorLayout(xy_shape, x.dtype));
param_desc.set(xy_desc.desc, bn_mode);
}
} // namespace batch_normalization
void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
_megdnn_tensor_in bn_bias, _megdnn_tensor_out mean,
_megdnn_tensor_out variance,
_megdnn_tensor_out batch_mean,
_megdnn_tensor_out batch_inv_variance,
_megdnn_tensor_out dst, _megdnn_workspace workspace) {
check_exec(src.layout, bn_scale.layout, bn_bias.layout, mean.layout,
variance.layout, batch_mean.layout, batch_inv_variance.layout,
dst.layout, workspace.size);
auto handle = concrete_handle(this->handle())->miopen_handle();
m_tensor_desc.setup(src.layout, m_param.param_dim);
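// MIOpen blend factors: dst = alpha * result + beta * dst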
float alpha = 1.0f, beta = 0.0f;
switch (m_param.fwd_mode) {
case param::BN::FwdMode::TRAINING:
miopen_check(miopenBatchNormalizationForwardTraining(
handle, m_tensor_desc.bn_mode, &alpha, &beta,
m_tensor_desc.xy_desc.desc, // xDesc
src.raw_ptr, // x
m_tensor_desc.xy_desc.desc, // yDesc
dst.raw_ptr, // y
m_tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc
bn_scale.raw_ptr, bn_bias.raw_ptr, m_param.avg_factor, // bnScale, bnBias, expAvgFactor
mean.raw_ptr, variance.raw_ptr, m_param.epsilon, // resultRunningMean, resultRunningVariance, epsilon
batch_mean.raw_ptr, batch_inv_variance.raw_ptr)); // resultSaveMean, resultSaveInvVariance
break;
case param::BN::FwdMode::INFERENCE:
miopen_check(miopenBatchNormalizationForwardInference(
handle, m_tensor_desc.bn_mode, &alpha, &beta,
m_tensor_desc.xy_desc.desc, src.raw_ptr,
m_tensor_desc.xy_desc.desc, dst.raw_ptr,
m_tensor_desc.param_desc.desc, bn_scale.raw_ptr,
bn_bias.raw_ptr, mean.raw_ptr, variance.raw_ptr,
m_param.epsilon));
break;
default:
megdnn_throw(megdnn_mangle(
"Unknown forward mode type of batch normalization."));
}
}
void BNBackwardImpl::exec(_megdnn_tensor_in x, _megdnn_tensor_in dy,
_megdnn_tensor_in saved_batch_mean,
_megdnn_tensor_in saved_batch_inv_variance,
_megdnn_tensor_in bn_scale,
_megdnn_tensor_out d_bn_scale,
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx,
_megdnn_workspace workspace) {
check_exec(x.layout, dy.layout, saved_batch_mean.layout,
saved_batch_inv_variance.layout, bn_scale.layout,
d_bn_scale.layout, d_bn_bias.layout, dx.layout, workspace.size);
auto handle = concrete_handle(this->handle())->miopen_handle();
m_tensor_desc.setup(x.layout, m_param.param_dim);
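// blend factors applied both to dx and to the scale/bias gradients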
float alpha = 1.0f, beta = 0.0f;
miopen_check(miopenBatchNormalizationBackward(
handle, m_tensor_desc.bn_mode, &alpha, &beta, &alpha, &beta,
m_tensor_desc.xy_desc.desc, x.raw_ptr, m_tensor_desc.xy_desc.desc,
dy.raw_ptr, m_tensor_desc.xy_desc.desc, dx.raw_ptr,
m_tensor_desc.param_desc.desc, bn_scale.raw_ptr, d_bn_scale.raw_ptr,
d_bn_bias.raw_ptr, m_param.epsilon, saved_batch_mean.raw_ptr,
saved_batch_inv_variance.raw_ptr));
}
} // namespace rocm
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/rocm/batch_normalization/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
#include "src/rocm/miopen_wrapper.h"
namespace megdnn {
namespace rocm {
namespace batch_normalization {
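// Bundles the MIOpen descriptors needed by BN: the x/y tensor descriptor,
// the scale/bias/mean/var descriptor, and the BN mode derived from ParamDim.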
struct BNTensorDescHolder {
using ParamDim = param::BN::ParamDim;
TensorDesc xy_desc;
BNParamDesc param_desc;
miopenBatchNormMode_t bn_mode;
void setup(const TensorLayout& x, const ParamDim& param_dim);
};
} // namespace batch_normalization
class BNForwardImpl final : public BNForward {
public:
using BNForward::BNForward;
void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
_megdnn_tensor_in bn_bias, _megdnn_tensor_out mean,
_megdnn_tensor_out variance, _megdnn_tensor_out batch_mean,
_megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&,
const TensorLayout&,
const TensorLayout&) override {
return 0;
}
private:
batch_normalization::BNTensorDescHolder m_tensor_desc;
};
class BNBackwardImpl final : public BNBackward {
public:
using BNBackward::BNBackward;
void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy,
_megdnn_tensor_in saved_batch_mean,
_megdnn_tensor_in saved_batch_inv_variance,
_megdnn_tensor_in bn_scale, _megdnn_tensor_out d_bn_scale,
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&,
const TensorLayout&,
const TensorLayout&) override {
return 0;
}
private:
batch_normalization::BNTensorDescHolder m_tensor_desc;
};
} // namespace rocm
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -35,6 +35,7 @@
#include "src/rocm/linspace/opr_impl.h"
#include "src/rocm/argmxx/opr_impl.h"
#include "src/rocm/sleep/opr_impl.h"
#include "src/rocm/batch_normalization/opr_impl.h"
#include <miopen/version.h>
#include <hip/hip_version.h>
......@@ -171,6 +172,8 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(Linspace);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgmaxForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgminForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SleepForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNBackward);
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpragmas"
......
/**
* \file dnn/test/rocm/bn.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "test/rocm/fixture.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs.h"
#include "test/common/bn.h"
#include "test/common/checker.h"
#include "test/common/rng.h"
#include "test/common/tensor.h"
#include "test/common/workspace_wrapper.h"
namespace megdnn {
namespace test {
TEST_F(ROCM, BN_FORWARD) {
using namespace batch_normalization;
std::vector<TestArg> args = get_args();
Checker<BNForward> checker(handle_rocm());
for (auto&& arg : args) {
for (int i = 0; i < 8; ++i) {
checker.set_dtype(i, dtype::Float32());
}
checker.set_dtype(0, arg.dtype);
checker.set_epsilon(1e-3).set_param(arg.param);
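// exercise both cases: with and without running mean/variance tensors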
for (bool need_statistic : {false, true})
checker.exec({
arg.src,
arg.param_shape, // bn_scale
arg.param_shape, // bn_bias
need_statistic ? arg.param_shape
: TensorShape({0}), // mean
need_statistic ? arg.param_shape
: TensorShape({0}), // variance
arg.param_shape, // batch_mean
arg.param_shape, // batch_inv_variance
{} // dst
});
}
}
TEST_F(ROCM, BN_BACKWARD) {
using namespace batch_normalization;
std::vector<TestArg> args = get_args();
Checker<BNBackward> checker(handle_rocm());
for (auto&& arg : args) {
for (int i = 0; i < 8; ++i) {
checker.set_dtype(i, dtype::Float32());
}
checker.set_dtype(0, arg.dtype) // x
.set_dtype(1, arg.dtype) // dy
.set_dtype(7, arg.dtype); // dx
checker.set_epsilon(1e-3).set_param(arg.param).exec(
{arg.src, arg.src, arg.param_shape, arg.param_shape,
arg.param_shape, arg.param_shape, arg.param_shape, arg.src});
}
}
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen