提交 ffa22a5f 编写于 作者: K Kexin Zhao

Fix the cuDNN scaling parameter type: replace the removed `bn_param_type` with `typename CudnnDataType<T>::ScalingParamType` for batch-norm scale/bias/mean/variance tensors (float for HALF and FLOAT inputs, double for DOUBLE).

上级 e870947c
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <cfloat>
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
......@@ -27,7 +28,7 @@ using DataLayout = framework::DataLayout;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
using bn_param_type = CudnnDataType<T>::bn_param_type;
using ScalingParamType = typename CudnnDataType<T>::ScalingParamType;
void ExtractNCWHD(const framework::DDim &dims, const DataLayout &data_layout,
int *N, int *C, int *H, int *W, int *D) {
......@@ -121,15 +122,15 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
// alloc memory
y->mutable_data<T>(ctx.GetPlace());
mean_out->mutable_data<bn_param_type<T>>(ctx.GetPlace());
variance_out->mutable_data<bn_param_type<T>>(ctx.GetPlace());
saved_mean->mutable_data<bn_param_type<T>>(ctx.GetPlace());
saved_variance->mutable_data<bn_param_type<T>>(ctx.GetPlace());
mean_out->mutable_data<ScalingParamType<T>>(ctx.GetPlace());
variance_out->mutable_data<ScalingParamType<T>>(ctx.GetPlace());
saved_mean->mutable_data<ScalingParamType<T>>(ctx.GetPlace());
saved_variance->mutable_data<ScalingParamType<T>>(ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
math::SetConstant<platform::CUDADeviceContext, bn_param_type<T>> functor;
functor(dev_ctx, saved_mean, static_cast<bn_param_type<T>>(0));
functor(dev_ctx, saved_variance, static_cast<bn_param_type<T>>(0));
math::SetConstant<platform::CUDADeviceContext, ScalingParamType<T>> functor;
functor(dev_ctx, saved_mean, static_cast<ScalingParamType<T>>(0));
functor(dev_ctx, saved_variance, static_cast<ScalingParamType<T>>(0));
auto handle = dev_ctx.cudnn_handle();
......@@ -150,10 +151,10 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
CUDNN_BATCHNORM_SPATIAL, CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
data_desc_, y->template mutable_data<T>(ctx.GetPlace()),
bn_param_desc_, scale->template data<bn_param_type<T>>(),
bias->template data<bn_param_type<T>>(),
est_mean->template data<bn_param_type<T>>(),
est_var->template data<bn_param_type<T>>(), epsilon));
bn_param_desc_, scale->template data<ScalingParamType<T>>(),
bias->template data<ScalingParamType<T>>(),
est_mean->template data<ScalingParamType<T>>(),
est_var->template data<ScalingParamType<T>>(), epsilon));
} else {
// Run training mode.
// obtain running mean and running inv var, and see if we need to
......@@ -164,13 +165,14 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
handle, mode_, CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
data_desc_, x->template data<T>(), data_desc_,
y->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
scale->template data<bn_param_type<T>>(),
bias->template data<bn_param_type<T>>(), this_factor,
mean_out->template mutable_data<bn_param_type<T>>(ctx.GetPlace()),
variance_out->template mutable_data<bn_param_type<T>>(ctx.GetPlace()),
epsilon,
saved_mean->template mutable_data<bn_param_type<T>>(ctx.GetPlace()),
saved_variance->template mutable_data<bn_param_type<T>>(
scale->template data<ScalingParamType<T>>(),
bias->template data<ScalingParamType<T>>(), this_factor,
mean_out->template mutable_data<ScalingParamType<T>>(ctx.GetPlace()),
variance_out->template mutable_data<ScalingParamType<T>>(
ctx.GetPlace()),
epsilon, saved_mean->template mutable_data<ScalingParamType<T>>(
ctx.GetPlace()),
saved_variance->template mutable_data<ScalingParamType<T>>(
ctx.GetPlace())));
}
......
......@@ -85,9 +85,6 @@ template <>
class CudnnDataType<float16> {
public:
static const cudnnDataType_t type = CUDNN_DATA_HALF;
// cudnn batch norm requires that Scale, Bias, Mean, and Variance
// to be FLOAT tensors when the input x is HALF tensor
static const cudnnDataType_t bn_param_type = CUDNN_DATA_FLOAT;
// The scaling param type is float for HALF and FLOAT tensors
typedef const float ScalingParamType;
static ScalingParamType* kOne() {
......@@ -104,7 +101,6 @@ template <>
class CudnnDataType<float> {
public:
static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
static const cudnnDataType_t bn_param_type = CUDNN_DATA_FLOAT;
typedef const float ScalingParamType;
static ScalingParamType* kOne() {
static ScalingParamType v = 1.0;
......@@ -120,7 +116,6 @@ template <>
class CudnnDataType<double> {
public:
static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
static const cudnnDataType_t bn_param_type = CUDNN_DATA_DOUBLE;
typedef const double ScalingParamType;
static ScalingParamType* kOne() {
static ScalingParamType v = 1.0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册