提交 d3323268 编写于 作者: J Jacek Czaja

- Added unit tests for softmax is_test=True op

test=develop
上级 c1fccc29
......@@ -19,10 +19,10 @@ namespace paddle {
namespace operators {
namespace math {
template class SoftmaxFunctor<platform::CPUDeviceContext, float,true>;
template class SoftmaxFunctor<platform::CPUDeviceContext, float,false>;
template class SoftmaxFunctor<platform::CPUDeviceContext, double,true>;
template class SoftmaxFunctor<platform::CPUDeviceContext, double,false>;
template class SoftmaxFunctor<platform::CPUDeviceContext, float, true>;
template class SoftmaxFunctor<platform::CPUDeviceContext, float, false>;
template class SoftmaxFunctor<platform::CPUDeviceContext, double, true>;
template class SoftmaxFunctor<platform::CPUDeviceContext, double, false>;
template class SoftmaxGradFunctor<platform::CPUDeviceContext, float>;
template class SoftmaxGradFunctor<platform::CPUDeviceContext, double>;
......
......@@ -33,8 +33,8 @@ struct ValueClip {
};
template <typename DeviceContext, typename T, bool is_test>
void SoftmaxFunctor<DeviceContext, T, is_test>::operator()(const DeviceContext& context,
const framework::Tensor* X,
void SoftmaxFunctor<DeviceContext, T, is_test>::operator()(
const DeviceContext& context, const framework::Tensor* X,
framework::Tensor* Y) {
auto logits = EigenMatrix<T>::From(*X);
auto softmax = EigenMatrix<T>::From(*Y);
......@@ -67,8 +67,7 @@ void SoftmaxFunctor<DeviceContext, T, is_test>::operator()(const DeviceContext&
template <typename DeviceContext, typename T>
class SoftmaxFunctor<DeviceContext, T, true> {
void operator()(const DeviceContext& context,
const framework::Tensor* X,
void operator()(const DeviceContext& context, const framework::Tensor* X,
framework::Tensor* Y) {
auto logits = EigenMatrix<T>::From(*X);
auto softmax = EigenMatrix<T>::From(*Y);
......@@ -96,11 +95,9 @@ void operator()(const DeviceContext& context,
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class));
}
}
};
template <typename DeviceContext, typename T>
void SoftmaxGradFunctor<DeviceContext, T>::operator()(
const DeviceContext& context, const framework::Tensor* y,
......
......@@ -36,11 +36,11 @@ class SoftmaxKernel : public framework::OpKernel<T> {
Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
const bool is_test = context.Attr<bool>("is_test");
if( is_test == true) {
math::SoftmaxFunctor<DeviceContext, T,true>()(
if (is_test == true) {
math::SoftmaxFunctor<DeviceContext, T, true>()(
context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
} else {
math::SoftmaxFunctor<DeviceContext, T,false>()(
math::SoftmaxFunctor<DeviceContext, T, false>()(
context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
}
}
......
......@@ -42,8 +42,8 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
math::SoftmaxFunctor<platform::CPUDeviceContext, T, false>()(dev_ctx, logits,
softmax);
math::SoftmaxFunctor<platform::CPUDeviceContext, T, false>()(
dev_ctx, logits, softmax);
math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
dev_ctx, loss, softmax, labels, context.Attr<bool>("soft_label"),
context.Attr<int>("ignore_index"));
......
......@@ -35,6 +35,7 @@ class TestSoftmaxOp(OpTest):
self.op_type = "softmax"
self.use_cudnn = False
self.use_mkldnn = False
self.is_test = False
self.dtype = np.float32
self.init_kernel_type()
self.shape = self.get_x_shape()
......@@ -48,7 +49,8 @@ class TestSoftmaxOp(OpTest):
self.outputs = {'Out': out}
self.attrs = {
'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn
'use_mkldnn': self.use_mkldnn,
'is_test': self.is_test
}
def init_kernel_type(self):
......@@ -144,6 +146,11 @@ class TestSoftmaxFP16CUDNNOp2(TestSoftmaxFP16CUDNNOp):
return [2, 3, 4, 5]
class TestSoftmaxInference(TestSoftmaxOp):
def init_kernel_type(self):
self.is_test = True
class TestSoftmaxMKLDNNOp(TestSoftmaxOp):
def init_kernel_type(self):
self.use_mkldnn = True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册