提交 d3323268 编写于 作者: J Jacek Czaja

- Added unit tests for softmax is_test=True op

test=develop
上级 c1fccc29
...@@ -19,10 +19,10 @@ namespace paddle { ...@@ -19,10 +19,10 @@ namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
template class SoftmaxFunctor<platform::CPUDeviceContext, float,true>; template class SoftmaxFunctor<platform::CPUDeviceContext, float, true>;
template class SoftmaxFunctor<platform::CPUDeviceContext, float,false>; template class SoftmaxFunctor<platform::CPUDeviceContext, float, false>;
template class SoftmaxFunctor<platform::CPUDeviceContext, double,true>; template class SoftmaxFunctor<platform::CPUDeviceContext, double, true>;
template class SoftmaxFunctor<platform::CPUDeviceContext, double,false>; template class SoftmaxFunctor<platform::CPUDeviceContext, double, false>;
template class SoftmaxGradFunctor<platform::CPUDeviceContext, float>; template class SoftmaxGradFunctor<platform::CPUDeviceContext, float>;
template class SoftmaxGradFunctor<platform::CPUDeviceContext, double>; template class SoftmaxGradFunctor<platform::CPUDeviceContext, double>;
......
...@@ -33,9 +33,9 @@ struct ValueClip { ...@@ -33,9 +33,9 @@ struct ValueClip {
}; };
template <typename DeviceContext, typename T, bool is_test> template <typename DeviceContext, typename T, bool is_test>
void SoftmaxFunctor<DeviceContext, T, is_test>::operator()(const DeviceContext& context, void SoftmaxFunctor<DeviceContext, T, is_test>::operator()(
const framework::Tensor* X, const DeviceContext& context, const framework::Tensor* X,
framework::Tensor* Y) { framework::Tensor* Y) {
auto logits = EigenMatrix<T>::From(*X); auto logits = EigenMatrix<T>::From(*X);
auto softmax = EigenMatrix<T>::From(*Y); auto softmax = EigenMatrix<T>::From(*Y);
...@@ -67,40 +67,37 @@ void SoftmaxFunctor<DeviceContext, T, is_test>::operator()(const DeviceContext& ...@@ -67,40 +67,37 @@ void SoftmaxFunctor<DeviceContext, T, is_test>::operator()(const DeviceContext&
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class SoftmaxFunctor<DeviceContext, T, true> { class SoftmaxFunctor<DeviceContext, T, true> {
void operator()(const DeviceContext& context, void operator()(const DeviceContext& context, const framework::Tensor* X,
const framework::Tensor* X, framework::Tensor* Y) {
framework::Tensor* Y) { auto logits = EigenMatrix<T>::From(*X);
auto logits = EigenMatrix<T>::From(*X); auto softmax = EigenMatrix<T>::From(*Y);
auto softmax = EigenMatrix<T>::From(*Y);
const int kBatchDim = 0;
const int kBatchDim = 0; const int kClassDim = 1;
const int kClassDim = 1;
const int batch_size = logits.dimension(kBatchDim);
const int batch_size = logits.dimension(kBatchDim); const int num_classes = logits.dimension(kClassDim);
const int num_classes = logits.dimension(kClassDim);
Eigen::DSizes<int, 1> along_class(kClassDim);
Eigen::DSizes<int, 1> along_class(kClassDim); Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1); Eigen::DSizes<int, 2> one_by_class(1, num_classes);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
auto shifted_logits = (logits -
auto shifted_logits = (logits - logits.maximum(along_class)
logits.maximum(along_class) .eval()
.eval() .reshape(batch_by_one)
.reshape(batch_by_one) .broadcast(one_by_class));
.broadcast(one_by_class));
softmax.device(*context.eigen_device()) = shifted_logits.exp();
softmax.device(*context.eigen_device()) = shifted_logits.exp(); softmax.device(*context.eigen_device()) = (softmax *
softmax.device(*context.eigen_device()) = (softmax * softmax.sum(along_class)
softmax.sum(along_class) .inverse()
.inverse() .eval()
.eval() .reshape(batch_by_one)
.reshape(batch_by_one) .broadcast(one_by_class));
.broadcast(one_by_class)); }
}
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
void SoftmaxGradFunctor<DeviceContext, T>::operator()( void SoftmaxGradFunctor<DeviceContext, T>::operator()(
const DeviceContext& context, const framework::Tensor* y, const DeviceContext& context, const framework::Tensor* y,
......
...@@ -36,11 +36,11 @@ class SoftmaxKernel : public framework::OpKernel<T> { ...@@ -36,11 +36,11 @@ class SoftmaxKernel : public framework::OpKernel<T> {
Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1); Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
const bool is_test = context.Attr<bool>("is_test"); const bool is_test = context.Attr<bool>("is_test");
if( is_test == true) { if (is_test == true) {
math::SoftmaxFunctor<DeviceContext, T,true>()( math::SoftmaxFunctor<DeviceContext, T, true>()(
context.template device_context<DeviceContext>(), &X_2d, &Out_2d); context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
} else { } else {
math::SoftmaxFunctor<DeviceContext, T,false>()( math::SoftmaxFunctor<DeviceContext, T, false>()(
context.template device_context<DeviceContext>(), &X_2d, &Out_2d); context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
} }
} }
......
...@@ -42,8 +42,8 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> { ...@@ -42,8 +42,8 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
auto& dev_ctx = auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>(); context.template device_context<platform::CPUDeviceContext>();
math::SoftmaxFunctor<platform::CPUDeviceContext, T, false>()(dev_ctx, logits, math::SoftmaxFunctor<platform::CPUDeviceContext, T, false>()(
softmax); dev_ctx, logits, softmax);
math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()( math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
dev_ctx, loss, softmax, labels, context.Attr<bool>("soft_label"), dev_ctx, loss, softmax, labels, context.Attr<bool>("soft_label"),
context.Attr<int>("ignore_index")); context.Attr<int>("ignore_index"));
......
...@@ -35,6 +35,7 @@ class TestSoftmaxOp(OpTest): ...@@ -35,6 +35,7 @@ class TestSoftmaxOp(OpTest):
self.op_type = "softmax" self.op_type = "softmax"
self.use_cudnn = False self.use_cudnn = False
self.use_mkldnn = False self.use_mkldnn = False
self.is_test = False
self.dtype = np.float32 self.dtype = np.float32
self.init_kernel_type() self.init_kernel_type()
self.shape = self.get_x_shape() self.shape = self.get_x_shape()
...@@ -48,7 +49,8 @@ class TestSoftmaxOp(OpTest): ...@@ -48,7 +49,8 @@ class TestSoftmaxOp(OpTest):
self.outputs = {'Out': out} self.outputs = {'Out': out}
self.attrs = { self.attrs = {
'use_cudnn': self.use_cudnn, 'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn 'use_mkldnn': self.use_mkldnn,
'is_test': self.is_test
} }
def init_kernel_type(self): def init_kernel_type(self):
...@@ -144,6 +146,11 @@ class TestSoftmaxFP16CUDNNOp2(TestSoftmaxFP16CUDNNOp): ...@@ -144,6 +146,11 @@ class TestSoftmaxFP16CUDNNOp2(TestSoftmaxFP16CUDNNOp):
return [2, 3, 4, 5] return [2, 3, 4, 5]
class TestSoftmaxInference(TestSoftmaxOp):
def init_kernel_type(self):
self.is_test = True
class TestSoftmaxMKLDNNOp(TestSoftmaxOp): class TestSoftmaxMKLDNNOp(TestSoftmaxOp):
def init_kernel_type(self): def init_kernel_type(self):
self.use_mkldnn = True self.use_mkldnn = True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册