Unverified commit d35b5b58, authored by xiongkun, committed by GitHub

[phi] [infershape] transfer nll_loss infer shape into phi (#40375)

* transfer nll_loss infershape into phi
Parent 9ebe7276
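The substance of this commit, condensed for orientation (every identifier below appears verbatim in the diff that follows; nothing is invented): the hand-written NLLLossOp::InferShape override is deleted from nll_loss_op.cc, the same checks are re-implemented as phi::NllLossRawInferMeta in paddle/phi/infermeta/ternary.cc, and the operator registration is wired to the new meta function through a generated functor:

// Condensed from the nll_loss_op.cc hunk below; not a complete file.
// The generated functor replaces the removed InferShape override.
DECLARE_INFER_SHAPE_FUNCTOR(nll_loss, NllLossRawInferShapeFunctor,
                            PD_INFER_META(phi::NllLossRawInferMeta));

REGISTER_OPERATOR(nll_loss, ops::NLLLossOp, ops::NLLLossOpMaker,
                  ops::NLLLossGradMaker<paddle::framework::OpDesc>,
                  ops::NLLLossGradMaker<paddle::imperative::OpBase>,
                  NllLossRawInferShapeFunctor);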
@@ -14,7 +14,9 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/ternary.h"
namespace paddle {
namespace operators {
@@ -23,77 +25,6 @@ class NLLLossOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "NLLLoss");
OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "NLLLoss");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "NLLLoss");
OP_INOUT_CHECK(ctx->HasOutput("Total_weight"), "Output", "Total_weight",
"NLLLoss");
auto x_dims = ctx->GetInputDim("X");
auto label_dims = ctx->GetInputDim("Label");
auto reduction = ctx->Attrs().Get<std::string>("reduction");
PADDLE_ENFORCE_EQ(x_dims.size() == 2 || x_dims.size() == 4, true,
platform::errors::InvalidArgument(
"The tensor rank of Input(X) must be 2 or 4."));
bool contain_unknown_dim = phi::contain_unknown_dim(x_dims) ||
phi::contain_unknown_dim(label_dims);
bool check = ctx->IsRuntime() || !contain_unknown_dim;
if (check) {
PADDLE_ENFORCE_EQ(
x_dims[0], label_dims[0],
platform::errors::InvalidArgument(
"ShapeError: Expected input batch_size to match label batch_size,"
"But received: the Input(x) batch_size is [%s], the Input(label) "
" batch_size is [%s].",
x_dims[0], label_dims[0]));
if (ctx->HasInput("Weight")) {
auto w_dims = ctx->GetInputDim("Weight");
PADDLE_ENFORCE_EQ(w_dims.size(), 1,
platform::errors::InvalidArgument(
"Input(Weight) should be a 1D tensor."));
PADDLE_ENFORCE_EQ(
x_dims[1], w_dims[0],
platform::errors::InvalidArgument(
"Expected input tensor Weight's size should equal "
"to the first dimension of the input tensor X. But received "
"Weight's "
"size is %d, the first dimension of input X is %d",
w_dims[0], x_dims[1]));
}
}
if (x_dims.size() == 2) {
if (reduction == "none") {
ctx->SetOutputDim("Out", {x_dims[0]});
} else {
ctx->SetOutputDim("Out", {1});
}
} else if (x_dims.size() == 4) {
PADDLE_ENFORCE_EQ(label_dims.size(), 3,
platform::errors::InvalidArgument(
"Expected Input(Lable) dimensions=3, received %d.",
label_dims.size()));
auto input0 = x_dims[0];
auto input2 = x_dims[2];
auto input3 = x_dims[3];
auto label0 = label_dims[0];
auto label1 = label_dims[1];
auto label2 = label_dims[2];
PADDLE_ENFORCE_EQ(
input0 == label0 && input2 == label1 && input3 == label2, true,
platform::errors::InvalidArgument("Input(X) tensor shape should "
"match to Input(Label) tensor "
"shape."));
if (reduction == "none") {
ctx->SetOutputDim("Out", {x_dims[0], x_dims[2], x_dims[3]});
} else {
ctx->SetOutputDim("Out", {1});
}
}
ctx->SetOutputDim("Total_weight", {1});
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
@@ -259,8 +190,11 @@ class NLLLossGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace operators
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(nll_loss, NllLossRawInferShapeFunctor,
PD_INFER_META(phi::NllLossRawInferMeta));
namespace ops = paddle::operators;
REGISTER_OPERATOR(nll_loss, ops::NLLLossOp, ops::NLLLossOpMaker,
ops::NLLLossGradMaker<paddle::framework::OpDesc>,
ops::NLLLossGradMaker<paddle::imperative::OpBase>);
ops::NLLLossGradMaker<paddle::imperative::OpBase>,
NllLossRawInferShapeFunctor);
REGISTER_OPERATOR(nll_loss_grad, ops::NLLLossGradOp);
@@ -89,6 +89,86 @@ void AddmmInferMeta(const MetaTensor& input,
out->set_dtype(input.dtype());
}
void NllLossRawInferMeta(const MetaTensor& input,
const MetaTensor& label,
paddle::optional<const MetaTensor&> weight,
int64_t ignore_index,
const std::string& reduction,
MetaTensor* out,
MetaTensor* total_weight,
MetaConfig config) {
auto x_dims = input.dims();
auto label_dims = label.dims();
PADDLE_ENFORCE_EQ(x_dims.size() == 2 || x_dims.size() == 4,
true,
phi::errors::InvalidArgument(
"The tensor rank of Input(X) must be 2 or 4."));
bool contain_unknown_dim =
phi::contain_unknown_dim(x_dims) || phi::contain_unknown_dim(label_dims);
bool check = config.is_runtime || !contain_unknown_dim;
if (check) {
PADDLE_ENFORCE_EQ(
x_dims[0],
label_dims[0],
phi::errors::InvalidArgument(
"ShapeError: Expected input batch_size to match label batch_size,"
"But received: the Input(x) batch_size is [%s], the Input(label) "
" batch_size is [%s].",
x_dims[0],
label_dims[0]));
if (weight.get_ptr() != nullptr) {
auto w_dims = weight->dims();
PADDLE_ENFORCE_EQ(
w_dims.size(),
1,
phi::errors::InvalidArgument("Input(Weight) should be a 1D tensor."));
PADDLE_ENFORCE_EQ(
x_dims[1],
w_dims[0],
phi::errors::InvalidArgument(
"Expected input tensor Weight's size should equal "
"to the first dimension of the input tensor X. But received "
"Weight's "
"size is %d, the first dimension of input X is %d",
w_dims[0],
x_dims[1]));
}
}
if (x_dims.size() == 2) {
if (reduction == "none") {
out->set_dims({x_dims[0]});
} else {
out->set_dims({1});
}
} else if (x_dims.size() == 4) {
PADDLE_ENFORCE_EQ(label_dims.size(),
3,
phi::errors::InvalidArgument(
"Expected Input(Lable) dimensions=3, received %d.",
label_dims.size()));
auto input0 = x_dims[0];
auto input2 = x_dims[2];
auto input3 = x_dims[3];
auto label0 = label_dims[0];
auto label1 = label_dims[1];
auto label2 = label_dims[2];
PADDLE_ENFORCE_EQ(
input0 == label0 && input2 == label1 && input3 == label2,
true,
phi::errors::InvalidArgument("Input(X) tensor shape should "
"match to Input(Label) tensor "
"shape."));
if (reduction == "none") {
out->set_dims({x_dims[0], x_dims[2], x_dims[3]});
} else {
out->set_dims({1});
}
}
total_weight->set_dims({1});
out->set_dtype(input.dtype());
total_weight->set_dtype(input.dtype());
}
void ScatterInferMeta(const MetaTensor& x,
const MetaTensor& index,
const MetaTensor& updates,
......
@@ -56,6 +56,15 @@ void ScatterInferMeta(const MetaTensor& x,
bool overwrite,
MetaTensor* out);
void NllLossRawInferMeta(const MetaTensor& input,
const MetaTensor& label,
paddle::optional<const MetaTensor&> weight,
int64_t ignore_index,
const std::string& reduction,
MetaTensor* out,
MetaTensor* total_weight,
MetaConfig config = MetaConfig());
void ScatterNdAddInferMeta(const MetaTensor& x,
const MetaTensor& index,
const MetaTensor& updates,
......
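For reference, the output-shape contract that NllLossRawInferMeta enforces can be restated as a small standalone helper. This is an illustrative sketch only: it is not Paddle code, and the function name and the use of std::vector for dims are choices made here, but the rules it encodes mirror the meta function above (Total_weight is always set to {1} regardless of reduction):

// Standalone restatement of the Out-shape rule from NllLossRawInferMeta.
// Not Paddle code: NllLossOutDims and the std::vector dims type are
// illustrative; the shape rules themselves come from the diff above.
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

std::vector<int64_t> NllLossOutDims(const std::vector<int64_t>& x_dims,
                                    const std::string& reduction) {
  assert(x_dims.size() == 2 || x_dims.size() == 4);
  if (reduction != "none") {
    return {1};  // "mean" / "sum": the loss is reduced to a single value
  }
  if (x_dims.size() == 2) {
    return {x_dims[0]};  // {N, C} -> {N}
  }
  return {x_dims[0], x_dims[2], x_dims[3]};  // {N, C, H, W} -> {N, H, W}
}

int main() {
  assert((NllLossOutDims({8, 10}, "none") == std::vector<int64_t>{8}));
  assert((NllLossOutDims({8, 10, 7, 7}, "mean") == std::vector<int64_t>{1}));
  return 0;
}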