Unverified commit 7f346a76, authored by YuanRisheng, committed by GitHub

Delete redundant param in SoftmaxFunctor (#46003)

* Simplify the softmax functor by removing the redundant is_test template parameter

* Fix compile errors

* Fix CI failures
Parent 4f403d3e
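In short, this commit drops the `is_test` bool from `SoftmaxFunctor`'s template parameter list: training and inference had come to share the same implementation, so the flag only multiplied instantiations. A standalone sketch of the pattern (illustrative names, not Paddle's real classes; assumes n >= 1):

// Sketch of the refactor this commit applies: a bool template parameter
// whose value no longer changes behavior is deleted, halving the explicit
// instantiations. Illustrative only, not the patched Paddle code.
#include <cmath>
#include <cstddef>

template <typename T>  // was: template <typename T, bool is_test>
struct Softmax1D {
  void operator()(const T* in, T* out, std::size_t n) const {
    T mx = in[0];  // assumes n >= 1
    for (std::size_t i = 1; i < n; ++i) mx = in[i] > mx ? in[i] : mx;
    T sum = T(0);
    for (std::size_t i = 0; i < n; ++i) sum += (out[i] = std::exp(in[i] - mx));
    for (std::size_t i = 0; i < n; ++i) out[i] /= sum;
  }
};

// Before, each type needed a <T, true> and a <T, false> instantiation;
// after, one per element type, mirroring the CPU and GPU hunks below.
template struct Softmax1D<float>;
template struct Softmax1D<double>;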
@@ -21,10 +21,8 @@ namespace paddle {
 namespace operators {
 namespace math {
 
-template class SoftmaxFunctor<phi::CPUContext, float, true>;
-template class SoftmaxFunctor<phi::CPUContext, float, false>;
-template class SoftmaxFunctor<phi::CPUContext, double, true>;
-template class SoftmaxFunctor<phi::CPUContext, double, false>;
+template class SoftmaxFunctor<phi::CPUContext, float>;
+template class SoftmaxFunctor<phi::CPUContext, double>;
 
 template class SoftmaxGradFunctor<phi::CPUContext, float>;
 template class SoftmaxGradFunctor<phi::CPUContext, double>;
......
@@ -156,14 +156,10 @@ template class SoftmaxCUDNNFunctor<double, phi::GPUContext>;
 template class SoftmaxGradCUDNNFunctor<double, phi::GPUContext>;
 #endif
 
-template class SoftmaxFunctor<phi::GPUContext, platform::float16, false>;
-template class SoftmaxFunctor<phi::GPUContext, platform::float16, true>;
-template class SoftmaxFunctor<phi::GPUContext, platform::bfloat16, false>;
-template class SoftmaxFunctor<phi::GPUContext, platform::bfloat16, true>;
-template class SoftmaxFunctor<phi::GPUContext, float, false>;
-template class SoftmaxFunctor<phi::GPUContext, double, false>;
-template class SoftmaxFunctor<phi::GPUContext, float, true>;
-template class SoftmaxFunctor<phi::GPUContext, double, true>;
+template class SoftmaxFunctor<phi::GPUContext, platform::float16>;
+template class SoftmaxFunctor<phi::GPUContext, platform::bfloat16>;
+template class SoftmaxFunctor<phi::GPUContext, float>;
+template class SoftmaxFunctor<phi::GPUContext, double>;
 template class SoftmaxGradFunctor<phi::GPUContext, float>;
 template class SoftmaxGradFunctor<phi::GPUContext, double>;
 template class SoftmaxGradFunctor<phi::GPUContext, platform::float16>;
......
@@ -19,10 +19,7 @@ namespace paddle {
 namespace operators {
 namespace math {
 
-template <typename DeviceContext,
-          typename T,
-          bool is_test,
-          typename Enable = void>
+template <typename DeviceContext, typename T, typename Enable = void>
 class SoftmaxFunctor {
  public:
   void operator()(const DeviceContext& context,
......
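The defaulted `typename Enable = void` parameter survives the cleanup because it is what lets `softmax_impl.h` select a CPU-only partial specialization via `enable_if_CPU` (visible in the hunks below). A minimal self-contained sketch of that SFINAE hook, with hypothetical tag types standing in for `phi::CPUContext`:

// Minimal sketch of the "Enable = void" SFINAE hook this signature keeps
// (hypothetical tag types; Paddle keys on phi::CPUContext instead).
#include <type_traits>

struct CPUContext {};
struct GPUContext {};

template <typename DeviceContext, typename T, typename Enable = void>
struct Functor {
  static constexpr int path = 0;  // generic (device) implementation
};

// Selected whenever DeviceContext is CPUContext; for any other context the
// substitution fails silently and the primary template above is used.
template <typename DeviceContext, typename T>
struct Functor<DeviceContext, T,
               typename std::enable_if<
                   std::is_same<DeviceContext, CPUContext>::value>::type> {
  static constexpr int path = 1;  // CPU-only fast path
};

static_assert(Functor<CPUContext, float>::path == 1, "CPU picks the specialization");
static_assert(Functor<GPUContext, float>::path == 0, "GPU falls back to the primary");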
@@ -42,7 +42,7 @@ struct ValueClip {
   }
 };
 
-template <typename DeviceContext, typename T, bool is_test>
+template <typename DeviceContext, typename T>
 class SoftmaxEigen {
  public:
   void operator()(const DeviceContext& context,
@@ -103,8 +103,8 @@ class SoftmaxEigen {
   }
 };
 
-template <typename DeviceContext, bool is_test>
-class SoftmaxEigen<DeviceContext, platform::float16, is_test> {
+template <typename DeviceContext>
+class SoftmaxEigen<DeviceContext, platform::float16> {
  public:
   void operator()(const DeviceContext& context,
                   const int axis_dim,
@@ -161,8 +161,8 @@ class SoftmaxEigen<DeviceContext, platform::float16, is_test> {
   }
 };
 
-template <typename DeviceContext, bool is_test>
-class SoftmaxEigen<DeviceContext, platform::bfloat16, is_test> {
+template <typename DeviceContext>
+class SoftmaxEigen<DeviceContext, platform::bfloat16> {
  public:
   void operator()(const DeviceContext& context,
                   const int axis_dim,
@@ -219,21 +219,21 @@ class SoftmaxEigen<DeviceContext, platform::bfloat16, is_test> {
   }
 };
 
-template <typename DeviceContext, typename T, bool is_test, typename Enable>
-void SoftmaxFunctor<DeviceContext, T, is_test, Enable>::operator()(
+template <typename DeviceContext, typename T, typename Enable>
+void SoftmaxFunctor<DeviceContext, T, Enable>::operator()(
     const DeviceContext& context,
     const int axis_dim,
    const framework::Tensor* X,
     framework::Tensor* Y) {
-  SoftmaxEigen<DeviceContext, T, is_test>()(context, axis_dim, X, Y);
+  SoftmaxEigen<DeviceContext, T>()(context, axis_dim, X, Y);
 }
 
 template <class DeviceContext>
 using enable_if_CPU = typename std::enable_if<
     std::is_same<DeviceContext, phi::CPUContext>::value>::type;
 
-template <typename DeviceContext, typename T, bool is_test>
-class SoftmaxFunctor<DeviceContext, T, is_test, enable_if_CPU<DeviceContext>> {
+template <typename DeviceContext, typename T>
+class SoftmaxFunctor<DeviceContext, T, enable_if_CPU<DeviceContext>> {
  public:
   void operator()(const DeviceContext& context,
                   const int axis_dim,
@@ -267,35 +267,11 @@ class SoftmaxFunctor<DeviceContext, T, is_test, enable_if_CPU<DeviceContext>> {
         out_data += num_classes;
       }
     } else {
-      SoftmaxEigen<DeviceContext, T, is_test>()(context, axis_dim, X, Y);
+      SoftmaxEigen<DeviceContext, T>()(context, axis_dim, X, Y);
     }
   }
 };
 
-template <typename DeviceContext>
-class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
- public:
-  void operator()(const DeviceContext& context,
-                  const int axis_dim,
-                  const framework::Tensor* X,
-                  framework::Tensor* Y) {
-    const auto& in_dims = X->dims();
-    const float* in_data = X->data<float>();
-    float* out_data = Y->data<float>();
-    const int kBatchDim = 0;
-    const int kClassDim = 1;
-    // 2D data. Batch x C
-    auto compute_softmax =
-        jit::KernelFuncs<jit::SoftmaxTuple<float>, platform::CPUPlace>::Cache()
-            .At(in_dims[kClassDim]);
-    compute_softmax(in_data,
-                    out_data,
-                    in_dims[kClassDim],
-                    in_dims[kBatchDim],
-                    in_dims[kClassDim] / axis_dim);
-  }
-};
-
 template <typename DeviceContext, typename T>
 class SoftmaxGradEigen {
  public:
......
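Note that the whole `<float, true>` CPU specialization is deleted rather than rewritten: with `is_test` gone it can no longer be named, and its JIT fast path was an inference-only shortcut. What it did, roughly, was resolve a softmax kernel specialized for the class count once and call it on the whole 2-D buffer. A minimal sketch of that lookup pattern (hypothetical types; Paddle's `jit::KernelFuncs::Cache().At(n)` differs in detail):

// Sketch of the kernel-cache lookup the deleted specialization relied on.
#include <cstring>
#include <map>

using SoftmaxFn = void (*)(const float* in, float* out,
                           int num_classes, int batch, int remain);

void softmax_reference(const float* in, float* out,
                       int num_classes, int batch, int /*remain*/) {
  // Fallback placeholder standing in for the scalar reference kernel.
  std::memcpy(out, in, sizeof(float) * num_classes * batch);
}

struct KernelCache {
  // Returns a kernel tuned for num_classes, else the generic fallback.
  SoftmaxFn At(int num_classes) {
    auto it = tuned_.find(num_classes);
    return it != tuned_.end() ? it->second : softmax_reference;
  }
  std::map<int, SoftmaxFn> tuned_;  // e.g. populated with vectorized variants
};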
@@ -119,10 +119,3 @@ struct OneHotGenerator<CPUContext, T> {
 
 PD_REGISTER_KERNEL(
     gumbel_softmax, CPU, ALL_LAYOUT, phi::GumbelSoftmaxKernel, float, double) {}
-
-PD_REGISTER_KERNEL(gumbel_softmax_infer,
-                   CPU,
-                   ALL_LAYOUT,
-                   phi::GumbelSoftmaxInferKernel,
-                   float,
-                   double) {}
@@ -170,10 +170,3 @@ struct GumbleNoiseGenerator<GPUContext, T> {
 
 PD_REGISTER_KERNEL(
     gumbel_softmax, GPU, ALL_LAYOUT, phi::GumbelSoftmaxKernel, float, double) {}
-
-PD_REGISTER_KERNEL(gumbel_softmax_infer,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::GumbelSoftmaxInferKernel,
-                   float,
-                   double) {}
@@ -25,12 +25,4 @@ void GumbelSoftmaxKernel(const Context& dev_ctx,
                          int axis,
                          DenseTensor* out);
 
-template <typename T, typename Context>
-void GumbelSoftmaxInferKernel(const Context& dev_ctx,
-                              const DenseTensor& x,
-                              float temperature,
-                              bool hard,
-                              int axis,
-                              DenseTensor* out);
-
 }  // namespace phi
@@ -48,8 +48,7 @@ void GumbelSoftmaxKernelHelper(const Context& ctx,
                                float temperature,
                                bool hard,
                                int axis,
-                               DenseTensor* out,
-                               bool is_test) {
+                               DenseTensor* out) {
   const int rank = x.dims().size();
   axis = funcs::CanonicalAxis(axis, rank);
   int axis_dim = x.dims()[axis];
@@ -81,13 +80,8 @@ void GumbelSoftmaxKernelHelper(const Context& ctx,
                                             size_to_axis,
                                             size_from_axis,
                                             temperature);
-  if (is_test) {
-    paddle::operators::math::SoftmaxFunctor<Context, T, true>()(
-        ctx, axis_dim, &x_noise_2d, &out_2d);
-  } else {
-    paddle::operators::math::SoftmaxFunctor<Context, T, false>()(
-        ctx, axis_dim, &x_noise_2d, &out_2d);
-  }
+  paddle::operators::math::SoftmaxFunctor<Context, T>()(
+      ctx, axis_dim, &x_noise_2d, &out_2d);
 
   if (hard) {
     OneHotGenerator<Context, T>::Transform(ctx, x, out, axis);
@@ -101,19 +95,7 @@ void GumbelSoftmaxKernel(const Context& ctx,
                          bool hard,
                          int axis,
                          DenseTensor* out) {
-  GumbelSoftmaxKernelHelper<T, Context>(
-      ctx, x, temperature, hard, axis, out, false);
-}
-
-template <typename T, typename Context>
-void GumbelSoftmaxInferKernel(const Context& ctx,
-                              const DenseTensor& x,
-                              float temperature,
-                              bool hard,
-                              int axis,
-                              DenseTensor* out) {
-  GumbelSoftmaxKernelHelper<T, Context>(
-      ctx, x, temperature, hard, axis, out, true);
+  GumbelSoftmaxKernelHelper<T, Context>(ctx, x, temperature, hard, axis, out);
 }
 
 }  // namespace phi
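The helper change above is the runtime half of the cleanup: the code previously forked on `is_test` purely to choose between two behaviorally identical instantiations, so a single unconditional call suffices and the flag no longer has to be threaded through `GumbelSoftmaxKernelHelper`. A reduced sketch of the before/after (illustrative names):

// Reduced sketch of the call-site consolidation (illustrative names).
template <typename T>
struct Softmax {
  void operator()(const T* x, T* y, int n) const { /* shared math */ }
};

// Before: a runtime flag selected between two identical instantiations.
//   void Helper(const float* x, float* y, int n, bool is_test) {
//     if (is_test) Softmax<float /*, true */>()(x, y, n);
//     else         Softmax<float /*, false*/>()(x, y, n);
//   }
// After: one call, and one fewer parameter at every call site.
void Helper(const float* x, float* y, int n) { Softmax<float>()(x, y, n); }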
@@ -40,7 +40,7 @@ void SoftmaxKernel(const Context& dev_ctx,
   DenseTensor X_2d, Out_2d;
   X_2d.ShareDataWith(x).Resize({n, d});
   Out_2d.ShareDataWith(*out).Resize({n, d});
-  paddle::operators::math::SoftmaxFunctor<Context, T, false>()(
+  paddle::operators::math::SoftmaxFunctor<Context, T>()(
       dev_ctx, axis_dim, &X_2d, &Out_2d);
 }
 
......
@@ -18,19 +18,8 @@ namespace phi {
 
 KernelSignature GumbelSoftmaxOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
-  bool is_test = false;
-  if (ctx.HasAttr("is_test")) {
-    is_test = paddle::any_cast<bool>(ctx.Attr("is_test"));
-  }
-  if (is_test) {
-    return KernelSignature("gumbel_softmax_infer",
-                           {"X"},
-                           {"temperature", "hard", "axis"},
-                           {"Out"});
-  } else {
-    return KernelSignature(
-        "gumbel_softmax", {"X"}, {"temperature", "hard", "axis"}, {"Out"});
-  }
+  return KernelSignature(
+      "gumbel_softmax", {"X"}, {"temperature", "hard", "axis"}, {"Out"});
 }
 
 KernelSignature GumbelSoftmaxGradOpArgumentMapping(
......
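Finally, the op-to-kernel mapping no longer inspects the `is_test` attribute: both branches would now resolve to the same behavior, so only the `gumbel_softmax` signature remains and the attribute, if present, is simply ignored. A simplified sketch of the deleted dispatch (stand-in types for phi's `ArgumentMappingContext` and `KernelSignature`):

// Simplified stand-ins for phi's ArgumentMappingContext / KernelSignature.
#include <string>

struct Signature { std::string kernel; };
struct Ctx { bool has_is_test = false; bool is_test = false; };

// Before: returned {"gumbel_softmax_infer"} when the op carried
// is_test == true. After: the attribute is never read; one signature
// covers both training and inference.
Signature GumbelSoftmaxMapping(const Ctx& /*ctx*/) {
  return {"gumbel_softmax"};
}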