Commit 446d54f5 authored by Kexin Zhao

update

Parent ffa22a5f
@@ -83,7 +83,7 @@ class BatchNormOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
-      const ExecutionContext &ctx) const override {
+      const framework::ExecutionContext &ctx) const override {
     auto input_data_type =
         framework::ToDataType(ctx.Input<Tensor>("X")->type());
     // For float or float16 input tensor, the type of the scale, bias, mean,
...
@@ -28,7 +28,7 @@ using DataLayout = framework::DataLayout;
 template <typename T>
 using CudnnDataType = platform::CudnnDataType<T>;
 template <typename T>
-using ScalingParamType = typename CudnnDataType<T>::ScalingParamType;
+using BatchNormParamType = typename CudnnDataType<T>::BatchNormParamType;
 void ExtractNCWHD(const framework::DDim &dims, const DataLayout &data_layout,
                   int *N, int *C, int *H, int *W, int *D) {
@@ -122,15 +122,16 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
     // alloc memory
     y->mutable_data<T>(ctx.GetPlace());
-    mean_out->mutable_data<ScalingParamType<T>>(ctx.GetPlace());
-    variance_out->mutable_data<ScalingParamType<T>>(ctx.GetPlace());
-    saved_mean->mutable_data<ScalingParamType<T>>(ctx.GetPlace());
-    saved_variance->mutable_data<ScalingParamType<T>>(ctx.GetPlace());
+    mean_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
+    variance_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
+    saved_mean->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
+    saved_variance->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());

     auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    math::SetConstant<platform::CUDADeviceContext, ScalingParamType<T>> functor;
-    functor(dev_ctx, saved_mean, static_cast<ScalingParamType<T>>(0));
-    functor(dev_ctx, saved_variance, static_cast<ScalingParamType<T>>(0));
+    math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
+        functor;
+    functor(dev_ctx, saved_mean, static_cast<BatchNormParamType<T>>(0));
+    functor(dev_ctx, saved_variance, static_cast<BatchNormParamType<T>>(0));

     auto handle = dev_ctx.cudnn_handle();
@@ -151,10 +152,10 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
           CUDNN_BATCHNORM_SPATIAL, CudnnDataType<T>::kOne(),
           CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
           data_desc_, y->template mutable_data<T>(ctx.GetPlace()),
-          bn_param_desc_, scale->template data<ScalingParamType<T>>(),
-          bias->template data<ScalingParamType<T>>(),
-          est_mean->template data<ScalingParamType<T>>(),
-          est_var->template data<ScalingParamType<T>>(), epsilon));
+          bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
+          bias->template data<BatchNormParamType<T>>(),
+          est_mean->template data<BatchNormParamType<T>>(),
+          est_var->template data<BatchNormParamType<T>>(), epsilon));
     } else {
       // Run training mode.
       // obtain running mean and running inv var, and see if we need to
@@ -165,14 +166,15 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
           handle, mode_, CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
           data_desc_, x->template data<T>(), data_desc_,
           y->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
-          scale->template data<ScalingParamType<T>>(),
-          bias->template data<ScalingParamType<T>>(), this_factor,
-          mean_out->template mutable_data<ScalingParamType<T>>(ctx.GetPlace()),
-          variance_out->template mutable_data<ScalingParamType<T>>(
-              ctx.GetPlace()),
-          epsilon, saved_mean->template mutable_data<ScalingParamType<T>>(
-                       ctx.GetPlace()),
-          saved_variance->template mutable_data<ScalingParamType<T>>(
-              ctx.GetPlace())));
+          scale->template data<BatchNormParamType<T>>(),
+          bias->template data<BatchNormParamType<T>>(), this_factor,
+          mean_out->template mutable_data<BatchNormParamType<T>>(
+              ctx.GetPlace()),
+          variance_out->template mutable_data<BatchNormParamType<T>>(
+              ctx.GetPlace()),
+          epsilon, saved_mean->template mutable_data<BatchNormParamType<T>>(
+                       ctx.GetPlace()),
+          saved_variance->template mutable_data<BatchNormParamType<T>>(
+              ctx.GetPlace())));
     }
...
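One plausible reason for introducing a separate alias here rather than reusing `ScalingParamType`: the latter is `const`-qualified (it exists to point at the constant alpha/beta scaling factors passed to cuDNN), which makes it an awkward template argument when allocating writable outputs such as `mean_out` and `saved_mean`. A minimal sketch of the distinction, using a hypothetical `Alloc` helper standing in for `Tensor::mutable_data<T>` (not Paddle's actual API):

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for Tensor::mutable_data<T>(): takes n*sizeof(T)
// bytes from an arena and hands back a typed pointer.
template <typename T>
T* Alloc(std::vector<char>& arena, std::size_t n) {
  arena.resize(n * sizeof(T));
  return reinterpret_cast<T*>(arena.data());
}

using ScalingParamType = const float;  // const: only suitable for reads
using BatchNormParamType = float;      // non-const: usable for outputs

int main() {
  std::vector<char> arena;
  BatchNormParamType* mean_out = Alloc<BatchNormParamType>(arena, 4);
  for (int i = 0; i < 4; ++i) mean_out[i] = 0.0f;  // OK: buffer is writable
  // Alloc<ScalingParamType>(arena, 4) would instead return a const float*,
  // so the write in the loop above would not compile against it.
  return 0;
}
```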
@@ -86,7 +86,8 @@ class CudnnDataType<float16> {
  public:
   static const cudnnDataType_t type = CUDNN_DATA_HALF;
   // The scaling param type is float for HALF and FLOAT tensors
-  typedef const float ScalingParamType;
+  using ScalingParamType = const float;
+  using BatchNormParamType = float;
   static ScalingParamType* kOne() {
     static ScalingParamType v = 1.0;
     return &v;
@@ -101,7 +102,8 @@ template <>
 class CudnnDataType<float> {
  public:
   static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
-  typedef const float ScalingParamType;
+  using ScalingParamType = const float;
+  using BatchNormParamType = float;
   static ScalingParamType* kOne() {
     static ScalingParamType v = 1.0;
     return &v;
@@ -116,7 +118,8 @@ template <>
 class CudnnDataType<double> {
  public:
   static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
-  typedef const double ScalingParamType;
+  using ScalingParamType = const double;
+  using BatchNormParamType = double;
   static ScalingParamType* kOne() {
     static ScalingParamType v = 1.0;
     return &v;
...
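Taken together, the three specializations give `BatchNormParamType<T>` the mapping half→float, float→float, double→double, matching cuDNN's requirement that batch-norm scale, bias, mean, and variance are stored as float when the tensor data is half precision. A self-contained sketch of the same trait pattern (the `float16` stand-in and trait names below are illustrative, not Paddle's actual definitions):

```cpp
#include <iostream>
#include <type_traits>

// Stand-in for a 16-bit float type (illustrative only).
struct float16 {
  unsigned short x;
};

// Trait mapping a tensor element type T to the type cuDNN expects for
// batch-norm parameters (scale, bias, mean, variance).
template <typename T>
struct BatchNormParam {
  using type = T;  // default: parameters share the tensor element type
};

template <>
struct BatchNormParam<float16> {
  using type = float;  // half tensors still carry float parameters
};

template <typename T>
using BatchNormParamType = typename BatchNormParam<T>::type;

int main() {
  static_assert(std::is_same<BatchNormParamType<float16>, float>::value,
                "float16 tensors use float batch-norm parameters");
  static_assert(std::is_same<BatchNormParamType<float>, float>::value,
                "float tensors use float batch-norm parameters");
  static_assert(std::is_same<BatchNormParamType<double>, double>::value,
                "double tensors use double batch-norm parameters");
  std::cout << "trait mapping verified at compile time\n";
}
```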