Unverified commit ec09ef26 authored by Chen Weihang, committed by GitHub

[Phi] Add softmax infermeta functions (#40471)

* rename softmax kernel name

* move softmax infershape

* fix failed test
Parent 76f87034
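The core of the change is moving softmax's shape inference out of the operator's hand-written InferShape method and into a reusable phi InferMeta function, which is then attached to the operator through an infer-shape functor. A condensed sketch of that registration pattern follows (names taken from the diff below; illustrative only, not a complete file):

// Wrap the phi InferMeta function as a fluid InferShape functor.
DECLARE_INFER_SHAPE_FUNCTOR(softmax, SoftmaxInferShapeFunctor,
                            PD_INFER_META(phi::SoftmaxInferMeta));

// Pass the functor to the operator registration; the op class no longer
// needs its own InferShape override.
REGISTER_OPERATOR(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker,
                  ops::SoftmaxOpInferVarType,
                  ops::SoftmaxOpGradMaker<paddle::framework::OpDesc>,
                  ops::SoftmaxOpGradMaker<paddle::imperative::OpBase>,
                  ops::SoftmaxInplaceInferer, SoftmaxInferShapeFunctor);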
......@@ -24,6 +24,7 @@
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/kernel_registry.h"
USE_OP_ITSELF(elementwise_add);
USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
......@@ -32,6 +33,8 @@ USE_OP_DEVICE_KERNEL(relu, MKLDNN);
USE_OP_ITSELF(softmax);
USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
PD_DECLARE_KERNEL(softmax, CPU, ALL_LAYOUT);
namespace paddle {
namespace operators {
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
......@@ -23,6 +24,10 @@ limitations under the License. */
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -30,30 +35,6 @@ class SoftmaxOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of SoftmaxOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound("Output(Out) of SoftmaxOp is not found."));
auto dim_x = ctx->GetInputDim("X");
auto rank_x = dim_x.size();
auto axis = ctx->Attrs().Get<int>("axis");
PADDLE_ENFORCE_GE(axis, -rank_x,
platform::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(X)."));
PADDLE_ENFORCE_LT(axis, rank_x,
platform::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(X)."));
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
......@@ -168,23 +149,6 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("Out"), true,
platform::errors::InvalidArgument("Input(Out) is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::InvalidArgument("Input(Out@GRAD) is not found."));
PADDLE_ENFORCE_EQ(
ctx->GetInputDim("Out"),
ctx->GetInputDim(framework::GradVarName("Out")),
platform::errors::InvalidArgument("Input(Out) and its gradients "
"should have a same shape."));
ctx->SetOutputDim(framework::GradVarName("X"),
ctx->GetInputDim(framework::GradVarName("Out")));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
......@@ -244,9 +208,14 @@ DECLARE_INPLACE_OP_INFERER(SoftmaxInplaceInferer, {"X", "Out"});
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(softmax, SoftmaxInferShapeFunctor,
PD_INFER_META(phi::SoftmaxInferMeta));
REGISTER_OPERATOR(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker,
ops::SoftmaxOpInferVarType,
ops::SoftmaxOpGradMaker<paddle::framework::OpDesc>,
ops::SoftmaxOpGradMaker<paddle::imperative::OpBase>,
ops::SoftmaxInplaceInferer);
REGISTER_OPERATOR(softmax_grad, ops::SoftmaxOpGrad);
ops::SoftmaxInplaceInferer, SoftmaxInferShapeFunctor);
DECLARE_INFER_SHAPE_FUNCTOR(softmax_grad, SoftmaxGradInferShapeFunctor,
PD_INFER_META(phi::GeneralUnaryGradInferMeta));
REGISTER_OPERATOR(softmax_grad, ops::SoftmaxOpGrad,
SoftmaxGradInferShapeFunctor);
......@@ -64,6 +64,12 @@ void BilinearTensorProductGradInferMeta(const MetaTensor& x,
}
}
void GeneralUnaryGradInferMeta(const MetaTensor& x, MetaTensor* dx) {
if (dx) {
dx->share_meta(x);
}
}
void GeneralBinaryGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
MetaTensor* dx,
......
......@@ -30,6 +30,8 @@ void BilinearTensorProductGradInferMeta(const MetaTensor& x,
MetaTensor* dweight,
MetaTensor* dbias);
void GeneralUnaryGradInferMeta(const MetaTensor& x, MetaTensor* dx);
void GeneralBinaryGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
MetaTensor* dx,
......
......@@ -1409,6 +1409,25 @@ void ShardIndexInferMeta(const MetaTensor& in,
out->set_dtype(in.dtype());
}
void SoftmaxInferMeta(const MetaTensor& x, int axis, MetaTensor* out) {
auto dim_x = x.dims();
auto rank_x = dim_x.size();
PADDLE_ENFORCE_GE(axis,
-rank_x,
phi::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(X)."));
PADDLE_ENFORCE_LT(axis,
rank_x,
phi::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(X)."));
out->set_dims(x.dims());
out->set_dtype(x.dtype());
out->share_lod(x);
}
} // namespace phi
PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta);
......
......@@ -203,4 +203,6 @@ void ShardIndexInferMeta(const MetaTensor& in,
MetaTensor* out,
MetaConfig config = MetaConfig());
void SoftmaxInferMeta(const MetaTensor& x, int axis, MetaTensor* out);
} // namespace phi
......@@ -19,4 +19,4 @@ limitations under the License. */
#include "paddle/phi/kernels/impl/softmax_kernel_impl.h"
PD_REGISTER_KERNEL(
softmax, CPU, ALL_LAYOUT, phi::SoftmaxRawKernel, float, double) {}
softmax, CPU, ALL_LAYOUT, phi::SoftmaxKernel, float, double) {}
......@@ -23,7 +23,7 @@ limitations under the License. */
PD_REGISTER_KERNEL(softmax,
GPU,
ALL_LAYOUT,
phi::SoftmaxRawKernel,
phi::SoftmaxKernel,
float,
double,
phi::dtype::float16,
......
......@@ -21,10 +21,10 @@ limitations under the License. */
namespace phi {
template <typename T, typename Context>
void SoftmaxRawGPUDNNKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
DenseTensor* out) {
void SoftmaxGPUDNNKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
SoftmaxForwardCUDAKernelDriver<T>(dev_ctx, x, axis, out);
}
......@@ -35,7 +35,7 @@ void SoftmaxRawGPUDNNKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL(softmax,
GPUDNN,
ALL_LAYOUT,
phi::SoftmaxRawGPUDNNKernel,
phi::SoftmaxGPUDNNKernel,
float,
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -44,7 +44,7 @@ PD_REGISTER_KERNEL(softmax,
PD_REGISTER_KERNEL(softmax,
GPUDNN,
ALL_LAYOUT,
phi::SoftmaxRawGPUDNNKernel,
phi::SoftmaxGPUDNNKernel,
float,
double,
phi::dtype::float16,
......@@ -53,7 +53,7 @@ PD_REGISTER_KERNEL(softmax,
PD_REGISTER_KERNEL(softmax,
GPUDNN,
ALL_LAYOUT,
phi::SoftmaxRawGPUDNNKernel,
phi::SoftmaxGPUDNNKernel,
float,
double,
phi::dtype::float16) {}
......
......@@ -22,10 +22,10 @@ limitations under the License. */
namespace phi {
template <typename T, typename Context>
void SoftmaxRawKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
DenseTensor* out) {
void SoftmaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
DenseTensor* out) {
const int rank = x.dims().size();
const int calc_axis = phi::funcs::CanonicalAxis(axis, rank);
int axis_dim = x.dims()[calc_axis];
......
......@@ -19,20 +19,10 @@ limitations under the License. */
namespace phi {
template <typename T, typename Context>
void SoftmaxRawKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
DenseTensor* out);
template <typename T, typename Context>
void SoftmaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
DataType dtype,
DenseTensor* out) {
auto cast_x = phi::Cast<T, Context>(dev_ctx, x, dtype);
phi::SoftmaxRawKernel<T, Context>(dev_ctx, axis, out);
}
DenseTensor* out);
} // namespace phi
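After the rename, the raw-precision entry point is simply SoftmaxKernel, taking (dev_ctx, x, axis, out) as declared above. A hedged usage sketch of a direct call is shown below; in practice dispatch normally goes through the phi kernel registry and output shape is set by SoftmaxInferMeta, so the context type, tensors, and manual Resize here are assumptions for illustration only.

// Hypothetical direct invocation of the renamed kernel (sketch only).
// Assumes `dev_ctx` is a phi::CPUContext and `x` is an initialized
// phi::DenseTensor.
phi::DenseTensor out;
out.Resize(x.dims());  // normally handled by SoftmaxInferMeta
phi::SoftmaxKernel<float, phi::CPUContext>(dev_ctx, x, /*axis=*/-1, &out);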