未验证 提交 318dfa0d 编写于 作者: Z Zhang Ting 提交者: GitHub

remove eval in eigen function when dtype is fp16 (#23845)

上级 66dc8e30
......@@ -37,7 +37,9 @@ struct ValueClip {
};
template <typename DeviceContext, typename T, bool is_test>
void SoftmaxEigen(const DeviceContext& context, const int axis_dim,
class SoftmaxEigen {
public:
void operator()(const DeviceContext& context, const int axis_dim,
const framework::Tensor* X, framework::Tensor* Y) {
constexpr int kBatchDim = 0;
constexpr int kClassDim = 1;
......@@ -85,19 +87,80 @@ void SoftmaxEigen(const DeviceContext& context, const int axis_dim,
}
softmax.device(*context.eigen_device()) = softmax.exp();
softmax.device(*context.eigen_device()) = (softmax *
softmax.device(*context.eigen_device()) =
(softmax *
softmax.reshape(batch_axis_remain)
.sum(along_axis)
.inverse()
.eval()
.broadcast(one_axis));
}
}
};
template <typename DeviceContext, bool is_test>
class SoftmaxEigen<DeviceContext, platform::float16, is_test> {
public:
void operator()(const DeviceContext& context, const int axis_dim,
const framework::Tensor* X, framework::Tensor* Y) {
constexpr int kBatchDim = 0;
constexpr int kClassDim = 1;
constexpr int kAxisDim = 1;
auto logits = EigenMatrix<platform::float16>::From(*X);
auto softmax = EigenMatrix<platform::float16>::From(*Y);
const int batch_size = logits.dimension(kBatchDim);
const int num_classes = logits.dimension(kClassDim);
const int num_remain = num_classes / axis_dim;
Eigen::DSizes<int, 1> along_axis(kAxisDim);
Eigen::DSizes<int, 2> batch_classes(batch_size, num_classes);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
Eigen::DSizes<int, 3> batch_one_remain(batch_size, 1, num_remain);
Eigen::DSizes<int, 3> one_axis_one(1, axis_dim, 1);
Eigen::DSizes<int, 2> one_axis(1, axis_dim);
Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
// For numerical stability, logits should be shifted by maximum number along
// axis, calculate shifted_logits into softmax tensor for memory reuse.
if (num_remain == 1) {
// axis == -1, axis and class in same dimension, calculate along
// class dimension directly for higher performance
softmax.device(*context.eigen_device()) =
(logits -
logits.maximum(along_axis)
.reshape(batch_by_one)
.broadcast(one_by_class))
.unaryExpr(ValueClip<platform::float16>());
} else {
// axis != -1, class dimension split into (axis, remain), max and sum
// should be calculated along axis dimension
softmax.device(*context.eigen_device()) =
(logits.reshape(batch_axis_remain) -
logits.reshape(batch_axis_remain)
.maximum(along_axis)
.reshape(batch_one_remain)
.broadcast(one_axis_one)
.reshape(batch_classes))
.unaryExpr(ValueClip<platform::float16>());
}
softmax.device(*context.eigen_device()) = softmax.exp();
softmax.device(*context.eigen_device()) =
(softmax *
softmax.reshape(batch_axis_remain)
.sum(along_axis)
.inverse()
.broadcast(one_axis));
}
};
template <typename DeviceContext, typename T, bool is_test, typename Enable>
void SoftmaxFunctor<DeviceContext, T, is_test, Enable>::operator()(
const DeviceContext& context, const int axis_dim,
const framework::Tensor* X, framework::Tensor* Y) {
SoftmaxEigen<DeviceContext, T, is_test>(context, axis_dim, X, Y);
SoftmaxEigen<DeviceContext, T, is_test>()(context, axis_dim, X, Y);
}
template <class DeviceContext>
......@@ -137,7 +200,7 @@ class SoftmaxFunctor<DeviceContext, T, is_test, enable_if_CPU<DeviceContext>> {
out_data += num_classes;
}
} else {
SoftmaxEigen<DeviceContext, T, is_test>(context, axis_dim, X, Y);
SoftmaxEigen<DeviceContext, T, is_test>()(context, axis_dim, X, Y);
}
}
};
......@@ -162,9 +225,10 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
};
template <typename DeviceContext, typename T>
void SoftmaxGradEigen(const DeviceContext& context, const int axis_dim,
const framework::Tensor* y,
const framework::Tensor* y_grad,
class SoftmaxGradEigen {
public:
void operator()(const DeviceContext& context, const int axis_dim,
const framework::Tensor* y, const framework::Tensor* y_grad,
framework::Tensor* x_grad) {
auto softmax = EigenMatrix<T>::From(*y);
auto softmax_grad = EigenMatrix<T>::From(*y_grad);
......@@ -188,15 +252,49 @@ void SoftmaxGradEigen(const DeviceContext& context, const int axis_dim,
.sum(along_class)
.eval()
.broadcast(one_axis);
logits_grad.device(*context.eigen_device()) = (softmax_grad - dot) * softmax;
}
logits_grad.device(*context.eigen_device()) =
(softmax_grad - dot) * softmax;
}
};
template <typename DeviceContext>
class SoftmaxGradEigen<DeviceContext, platform::float16> {
public:
void operator()(const DeviceContext& context, const int axis_dim,
const framework::Tensor* y, const framework::Tensor* y_grad,
framework::Tensor* x_grad) {
auto softmax = EigenMatrix<platform::float16>::From(*y);
auto softmax_grad = EigenMatrix<platform::float16>::From(*y_grad);
auto logits_grad = EigenMatrix<platform::float16>::From(*x_grad);
constexpr int kBatchDim = 0;
constexpr int kClassDim = 1;
const int batch_size = softmax.dimension(kBatchDim);
const int num_classes = softmax.dimension(kClassDim);
const int num_remain = num_classes / axis_dim;
Eigen::DSizes<int, 1> along_class(kClassDim);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
Eigen::DSizes<int, 2> one_axis(1, axis_dim);
auto dot = (softmax * softmax_grad)
.reshape(batch_axis_remain)
.sum(along_class)
.broadcast(one_axis);
logits_grad.device(*context.eigen_device()) =
(softmax_grad - dot) * softmax;
}
};
template <typename DeviceContext, typename T, typename Enable>
void SoftmaxGradFunctor<DeviceContext, T, Enable>::operator()(
const DeviceContext& context, const int axis_dim,
const framework::Tensor* y, const framework::Tensor* y_grad,
framework::Tensor* x_grad) {
SoftmaxGradEigen<DeviceContext, T>(context, axis_dim, y, y_grad, x_grad);
SoftmaxGradEigen<DeviceContext, T>()(context, axis_dim, y, y_grad, x_grad);
}
template <typename DeviceContext, typename T>
......@@ -228,7 +326,8 @@ class SoftmaxGradFunctor<DeviceContext, T, enable_if_CPU<DeviceContext>> {
in_grad += num_classes;
}
} else {
SoftmaxGradEigen<DeviceContext, T>(context, axis_dim, y, y_grad, x_grad);
SoftmaxGradEigen<DeviceContext, T>()(context, axis_dim, y, y_grad,
x_grad);
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册