Unverified commit 7f346a76, authored by YuanRisheng, committed by GitHub

Delete redundant param in SoftmaxFunctor (#46003)

* perfect softmax functor

* fix compile bugs

* fix ci bugs
Parent 4f403d3e
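
The change below drops the `bool is_test` template parameter from `SoftmaxFunctor` (and the Gumbel-softmax code built on it), so one instantiation serves both training and inference. As a standalone illustration of the resulting shape of the functor — a minimal sketch with hypothetical names, not code from this repository — the pattern looks like this:

// Standalone sketch (hypothetical names, illustration only): a softmax functor
// templated on the value type alone, mirroring the post-change
// SoftmaxFunctor<DeviceContext, T> shape without the removed is_test flag.
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

template <typename T>
struct SoftmaxLike {
  void operator()(const std::vector<T>& x, std::vector<T>* y) const {
    T max_v = x[0];
    for (T v : x) max_v = v > max_v ? v : max_v;  // subtract max for stability
    y->resize(x.size());
    T sum = 0;
    for (std::size_t i = 0; i < x.size(); ++i) {
      (*y)[i] = std::exp(x[i] - max_v);
      sum += (*y)[i];
    }
    for (T& v : *y) v /= sum;  // normalize to probabilities
  }
};

int main() {
  std::vector<float> x{1.f, 2.f, 3.f}, y;
  SoftmaxLike<float>()(x, &y);  // one instantiation; no separate is_test variant
  for (float v : y) std::cout << v << ' ';
  std::cout << '\n';
  return 0;
}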
@@ -21,10 +21,8 @@ namespace paddle {
 namespace operators {
 namespace math {
 
-template class SoftmaxFunctor<phi::CPUContext, float, true>;
-template class SoftmaxFunctor<phi::CPUContext, float, false>;
-template class SoftmaxFunctor<phi::CPUContext, double, true>;
-template class SoftmaxFunctor<phi::CPUContext, double, false>;
+template class SoftmaxFunctor<phi::CPUContext, float>;
+template class SoftmaxFunctor<phi::CPUContext, double>;
 
 template class SoftmaxGradFunctor<phi::CPUContext, float>;
 template class SoftmaxGradFunctor<phi::CPUContext, double>;
...
@@ -156,14 +156,10 @@ template class SoftmaxCUDNNFunctor<double, phi::GPUContext>;
 template class SoftmaxGradCUDNNFunctor<double, phi::GPUContext>;
 #endif
 
-template class SoftmaxFunctor<phi::GPUContext, platform::float16, false>;
-template class SoftmaxFunctor<phi::GPUContext, platform::float16, true>;
-template class SoftmaxFunctor<phi::GPUContext, platform::bfloat16, false>;
-template class SoftmaxFunctor<phi::GPUContext, platform::bfloat16, true>;
-template class SoftmaxFunctor<phi::GPUContext, float, false>;
-template class SoftmaxFunctor<phi::GPUContext, double, false>;
-template class SoftmaxFunctor<phi::GPUContext, float, true>;
-template class SoftmaxFunctor<phi::GPUContext, double, true>;
+template class SoftmaxFunctor<phi::GPUContext, platform::float16>;
+template class SoftmaxFunctor<phi::GPUContext, platform::bfloat16>;
+template class SoftmaxFunctor<phi::GPUContext, float>;
+template class SoftmaxFunctor<phi::GPUContext, double>;
 template class SoftmaxGradFunctor<phi::GPUContext, float>;
 template class SoftmaxGradFunctor<phi::GPUContext, double>;
 template class SoftmaxGradFunctor<phi::GPUContext, platform::float16>;
...
@@ -19,10 +19,7 @@ namespace paddle {
 namespace operators {
 namespace math {
 
-template <typename DeviceContext,
-          typename T,
-          bool is_test,
-          typename Enable = void>
+template <typename DeviceContext, typename T, typename Enable = void>
 class SoftmaxFunctor {
  public:
   void operator()(const DeviceContext& context,
...
@@ -42,7 +42,7 @@ struct ValueClip {
   }
 };
 
-template <typename DeviceContext, typename T, bool is_test>
+template <typename DeviceContext, typename T>
 class SoftmaxEigen {
  public:
   void operator()(const DeviceContext& context,
@@ -103,8 +103,8 @@ class SoftmaxEigen {
   }
 };
 
-template <typename DeviceContext, bool is_test>
-class SoftmaxEigen<DeviceContext, platform::float16, is_test> {
+template <typename DeviceContext>
+class SoftmaxEigen<DeviceContext, platform::float16> {
  public:
   void operator()(const DeviceContext& context,
                   const int axis_dim,
@@ -161,8 +161,8 @@ class SoftmaxEigen<DeviceContext, platform::float16, is_test> {
   }
 };
 
-template <typename DeviceContext, bool is_test>
-class SoftmaxEigen<DeviceContext, platform::bfloat16, is_test> {
+template <typename DeviceContext>
+class SoftmaxEigen<DeviceContext, platform::bfloat16> {
  public:
   void operator()(const DeviceContext& context,
                   const int axis_dim,
@@ -219,21 +219,21 @@ class SoftmaxEigen<DeviceContext, platform::bfloat16, is_test> {
   }
 };
 
-template <typename DeviceContext, typename T, bool is_test, typename Enable>
-void SoftmaxFunctor<DeviceContext, T, is_test, Enable>::operator()(
+template <typename DeviceContext, typename T, typename Enable>
+void SoftmaxFunctor<DeviceContext, T, Enable>::operator()(
     const DeviceContext& context,
     const int axis_dim,
     const framework::Tensor* X,
     framework::Tensor* Y) {
-  SoftmaxEigen<DeviceContext, T, is_test>()(context, axis_dim, X, Y);
+  SoftmaxEigen<DeviceContext, T>()(context, axis_dim, X, Y);
 }
 
 template <class DeviceContext>
 using enable_if_CPU = typename std::enable_if<
     std::is_same<DeviceContext, phi::CPUContext>::value>::type;
 
-template <typename DeviceContext, typename T, bool is_test>
-class SoftmaxFunctor<DeviceContext, T, is_test, enable_if_CPU<DeviceContext>> {
+template <typename DeviceContext, typename T>
+class SoftmaxFunctor<DeviceContext, T, enable_if_CPU<DeviceContext>> {
  public:
   void operator()(const DeviceContext& context,
                   const int axis_dim,
@@ -267,35 +267,11 @@ class SoftmaxFunctor<DeviceContext, T, is_test, enable_if_CPU<DeviceContext>> {
       out_data += num_classes;
     }
   } else {
-    SoftmaxEigen<DeviceContext, T, is_test>()(context, axis_dim, X, Y);
+    SoftmaxEigen<DeviceContext, T>()(context, axis_dim, X, Y);
   }
  }
 };
 
-template <typename DeviceContext>
-class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
- public:
-  void operator()(const DeviceContext& context,
-                  const int axis_dim,
-                  const framework::Tensor* X,
-                  framework::Tensor* Y) {
-    const auto& in_dims = X->dims();
-    const float* in_data = X->data<float>();
-    float* out_data = Y->data<float>();
-    const int kBatchDim = 0;
-    const int kClassDim = 1;
-    // 2D data. Batch x C
-    auto compute_softmax =
-        jit::KernelFuncs<jit::SoftmaxTuple<float>, platform::CPUPlace>::Cache()
-            .At(in_dims[kClassDim]);
-    compute_softmax(in_data,
-                    out_data,
-                    in_dims[kClassDim],
-                    in_dims[kBatchDim],
-                    in_dims[kClassDim] / axis_dim);
-  }
-};
-
 template <typename DeviceContext, typename T>
 class SoftmaxGradEigen {
  public:
...
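
For readers unfamiliar with the `enable_if_CPU` trick kept by this change: the primary `SoftmaxFunctor` template carries a trailing `typename Enable = void` parameter, and the CPU path is a partial specialization selected via `std::enable_if`. A self-contained sketch of that dispatch pattern (hypothetical context types, not Paddle's) is:

// Standalone sketch (hypothetical types): SFINAE-based partial specialization
// picked only when the device context is the CPU one.
#include <iostream>
#include <type_traits>

struct CPUContext {};
struct GPUContext {};

template <class DeviceContext>
using enable_if_CPU = typename std::enable_if<
    std::is_same<DeviceContext, CPUContext>::value>::type;

// Primary template: the generic (e.g. Eigen-based) path.
template <typename DeviceContext, typename T, typename Enable = void>
struct Functor {
  void operator()() const { std::cout << "generic path\n"; }
};

// Partial specialization: chosen only for CPUContext, e.g. to call a JIT kernel.
template <typename DeviceContext, typename T>
struct Functor<DeviceContext, T, enable_if_CPU<DeviceContext>> {
  void operator()() const { std::cout << "CPU fast path\n"; }
};

int main() {
  Functor<CPUContext, float>()();  // prints "CPU fast path"
  Functor<GPUContext, float>()();  // prints "generic path"
  return 0;
}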
@@ -119,10 +119,3 @@ struct OneHotGenerator<CPUContext, T> {
 PD_REGISTER_KERNEL(
     gumbel_softmax, CPU, ALL_LAYOUT, phi::GumbelSoftmaxKernel, float, double) {}
-
-PD_REGISTER_KERNEL(gumbel_softmax_infer,
-                   CPU,
-                   ALL_LAYOUT,
-                   phi::GumbelSoftmaxInferKernel,
-                   float,
-                   double) {}
@@ -170,10 +170,3 @@ struct GumbleNoiseGenerator<GPUContext, T> {
 PD_REGISTER_KERNEL(
     gumbel_softmax, GPU, ALL_LAYOUT, phi::GumbelSoftmaxKernel, float, double) {}
-
-PD_REGISTER_KERNEL(gumbel_softmax_infer,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::GumbelSoftmaxInferKernel,
-                   float,
-                   double) {}
@@ -25,12 +25,4 @@ void GumbelSoftmaxKernel(const Context& dev_ctx,
                          int axis,
                          DenseTensor* out);
-
-template <typename T, typename Context>
-void GumbelSoftmaxInferKernel(const Context& dev_ctx,
-                              const DenseTensor& x,
-                              float temperature,
-                              bool hard,
-                              int axis,
-                              DenseTensor* out);
-
 }  // namespace phi
@@ -48,8 +48,7 @@ void GumbelSoftmaxKernelHelper(const Context& ctx,
                                float temperature,
                                bool hard,
                                int axis,
-                               DenseTensor* out,
-                               bool is_test) {
+                               DenseTensor* out) {
   const int rank = x.dims().size();
   axis = funcs::CanonicalAxis(axis, rank);
   int axis_dim = x.dims()[axis];
...@@ -81,13 +80,8 @@ void GumbelSoftmaxKernelHelper(const Context& ctx, ...@@ -81,13 +80,8 @@ void GumbelSoftmaxKernelHelper(const Context& ctx,
size_to_axis, size_to_axis,
size_from_axis, size_from_axis,
temperature); temperature);
if (is_test) { paddle::operators::math::SoftmaxFunctor<Context, T>()(
paddle::operators::math::SoftmaxFunctor<Context, T, true>()( ctx, axis_dim, &x_noise_2d, &out_2d);
ctx, axis_dim, &x_noise_2d, &out_2d);
} else {
paddle::operators::math::SoftmaxFunctor<Context, T, false>()(
ctx, axis_dim, &x_noise_2d, &out_2d);
}
if (hard) { if (hard) {
OneHotGenerator<Context, T>::Transform(ctx, x, out, axis); OneHotGenerator<Context, T>::Transform(ctx, x, out, axis);
@@ -101,19 +95,7 @@ void GumbelSoftmaxKernel(const Context& ctx,
                          bool hard,
                          int axis,
                          DenseTensor* out) {
-  GumbelSoftmaxKernelHelper<T, Context>(
-      ctx, x, temperature, hard, axis, out, false);
-}
-
-template <typename T, typename Context>
-void GumbelSoftmaxInferKernel(const Context& ctx,
-                              const DenseTensor& x,
-                              float temperature,
-                              bool hard,
-                              int axis,
-                              DenseTensor* out) {
-  GumbelSoftmaxKernelHelper<T, Context>(
-      ctx, x, temperature, hard, axis, out, true);
+  GumbelSoftmaxKernelHelper<T, Context>(ctx, x, temperature, hard, axis, out);
 }
 
 }  // namespace phi
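
For context on what GumbelSoftmaxKernelHelper computes, the standard Gumbel-softmax formulation (Jang et al., 2017) is stated here for reference; it is not spelled out in the diff itself, where the noise handling lives in GumbleNoiseGenerator:

g_i = -\log(-\log u_i), \quad u_i \sim \mathrm{Uniform}(0, 1)

y_i = \frac{\exp\big((\log \pi_i + g_i)/\tau\big)}{\sum_j \exp\big((\log \pi_j + g_j)/\tau\big)}

Here \tau is the `temperature` attribute, and when `hard` is true the soft sample is additionally converted into a one-hot vector along `axis` by OneHotGenerator, as shown in the hunk above.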
@@ -40,7 +40,7 @@ void SoftmaxKernel(const Context& dev_ctx,
   DenseTensor X_2d, Out_2d;
   X_2d.ShareDataWith(x).Resize({n, d});
   Out_2d.ShareDataWith(*out).Resize({n, d});
-  paddle::operators::math::SoftmaxFunctor<Context, T, false>()(
+  paddle::operators::math::SoftmaxFunctor<Context, T>()(
       dev_ctx, axis_dim, &X_2d, &Out_2d);
 }
...
@@ -18,19 +18,8 @@ namespace phi {
 KernelSignature GumbelSoftmaxOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
-  bool is_test = false;
-  if (ctx.HasAttr("is_test")) {
-    is_test = paddle::any_cast<bool>(ctx.Attr("is_test"));
-  }
-  if (is_test) {
-    return KernelSignature("gumbel_softmax_infer",
-                           {"X"},
-                           {"temperature", "hard", "axis"},
-                           {"Out"});
-  } else {
-    return KernelSignature(
-        "gumbel_softmax", {"X"}, {"temperature", "hard", "axis"}, {"Out"});
-  }
+  return KernelSignature(
+      "gumbel_softmax", {"X"}, {"temperature", "hard", "axis"}, {"Out"});
 }
 
 KernelSignature GumbelSoftmaxGradOpArgumentMapping(
...