提交 9b0eae30 编写于 作者: J Jacek Czaja

- Removing partial specialization of sotmax for inference for GPU

test=develop
上级 be80bb4f
...@@ -19,7 +19,8 @@ namespace paddle { ...@@ -19,7 +19,8 @@ namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
template <typename DeviceContext, typename T, bool is_test> template <typename DeviceContext, typename T, bool is_test,
typename Enable = void>
class SoftmaxFunctor { class SoftmaxFunctor {
public: public:
void operator()(const DeviceContext& context, const framework::Tensor* X, void operator()(const DeviceContext& context, const framework::Tensor* X,
......
...@@ -33,8 +33,8 @@ struct ValueClip { ...@@ -33,8 +33,8 @@ struct ValueClip {
} }
}; };
template <typename DeviceContext, typename T, bool is_test> template <typename DeviceContext, typename T, bool is_test, typename Enable>
void SoftmaxFunctor<DeviceContext, T, is_test>::operator()( void SoftmaxFunctor<DeviceContext, T, is_test, Enable>::operator()(
const DeviceContext& context, const framework::Tensor* X, const DeviceContext& context, const framework::Tensor* X,
framework::Tensor* Y) { framework::Tensor* Y) {
auto logits = EigenMatrix<T>::From(*X); auto logits = EigenMatrix<T>::From(*X);
...@@ -66,8 +66,12 @@ void SoftmaxFunctor<DeviceContext, T, is_test>::operator()( ...@@ -66,8 +66,12 @@ void SoftmaxFunctor<DeviceContext, T, is_test>::operator()(
.broadcast(one_by_class)); .broadcast(one_by_class));
} }
template <class DeviceContext>
using enable_if_CPU = typename std::enable_if<
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type;
template <typename DeviceContext> template <typename DeviceContext>
class SoftmaxFunctor<DeviceContext, float, true> { class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
void operator()(const DeviceContext& context, const framework::Tensor* X, void operator()(const DeviceContext& context, const framework::Tensor* X,
framework::Tensor* Y) { framework::Tensor* Y) {
auto in_dims = X->dims(); auto in_dims = X->dims();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册