未验证 提交 f3962530 编写于 作者: X xiongkun 提交者: GitHub

[phi] transfer pad infer shape function into phi infer meta (#40158)

* pad infershape

* fix code

* fix

* add set dtype
上级 d35b5b58
......@@ -13,8 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -28,37 +30,6 @@ class PadOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Pad");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Pad");
auto x_dim = ctx->GetInputDim("X");
auto& paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
PADDLE_ENFORCE_EQ(
static_cast<int>(paddings.size()), x_dim.size() * 2,
platform::errors::InvalidArgument(
"Size of 'paddings' dimension should be equal to 2 * size of "
"Input(X)'s dimension, but received (size of 'paddings' dimension "
"is) %d vs (2 * size of Input(X)'s dimension is) %d.",
static_cast<int>(paddings.size()), x_dim.size() * 2));
for (size_t i = 0; i < paddings.size(); ++i) {
PADDLE_ENFORCE_GE(paddings[i], 0,
platform::errors::InvalidArgument(
"The element of 'paddings' should >= 0, but "
"received %d for index %d.",
paddings[i], static_cast<int>(i)));
}
std::vector<int64_t> out_dims(x_dim.size());
for (int i = 0; i < x_dim.size(); ++i) {
if ((!ctx->IsRuntime()) && (x_dim[i] == -1)) {
out_dims[i] = -1;
} else {
out_dims[i] = x_dim[i] + paddings[i * 2] + paddings[i * 2 + 1];
}
}
ctx->SetOutputDim("Out", phi::make_ddim(out_dims));
if (out_dims[0] == x_dim[0]) {
// Only pass LoD when the first dimension is equal between
// output and input.
ctx->ShareLoD("X", /*->*/ "Out");
}
}
};
......@@ -160,10 +131,13 @@ class PadOpDoubleGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(pad, PadInferShapeFunctor,
PD_INFER_META(phi::PadInferMeta));
REGISTER_OPERATOR(pad, ops::PadOp, ops::PadOpMaker,
ops::PadOpGradMaker<paddle::framework::OpDesc>,
ops::PadOpGradMaker<paddle::imperative::OpBase>);
ops::PadOpGradMaker<paddle::imperative::OpBase>,
PadInferShapeFunctor);
REGISTER_OPERATOR(pad_grad, ops::PadOpGrad,
ops::PadOpDoubleGradMaker<paddle::framework::OpDesc>,
ops::PadOpDoubleGradMaker<paddle::imperative::OpBase>);
......@@ -1124,6 +1124,47 @@ void SizeInferMeta(const MetaTensor& input, MetaTensor* out) {
out->set_dims({1});
}
void PadInferMeta(const MetaTensor& input,
const std::vector<int>& paddings,
float pad_value,
MetaTensor* out,
MetaConfig config) {
auto x_dim = input.dims();
PADDLE_ENFORCE_EQ(
static_cast<int>(paddings.size()),
x_dim.size() * 2,
phi::errors::InvalidArgument(
"Size of 'paddings' dimension should be equal to 2 * size of "
"Input(X)'s dimension, but received (size of 'paddings' dimension "
"is) %d vs (2 * size of Input(X)'s dimension is) %d.",
static_cast<int>(paddings.size()),
x_dim.size() * 2));
for (size_t i = 0; i < paddings.size(); ++i) {
PADDLE_ENFORCE_GE(paddings[i],
0,
phi::errors::InvalidArgument(
"The element of 'paddings' should >= 0, but "
"received %d for index %d.",
paddings[i],
static_cast<int>(i)));
}
std::vector<int64_t> out_dims(x_dim.size());
for (int i = 0; i < x_dim.size(); ++i) {
if ((!config.is_runtime) && (x_dim[i] == -1)) {
out_dims[i] = -1;
} else {
out_dims[i] = x_dim[i] + paddings[i * 2] + paddings[i * 2 + 1];
}
}
out->set_dims(phi::make_ddim(out_dims));
if (out_dims[0] == x_dim[0]) {
// Only pass LoD when the first dimension is equal between
// output and input.
out->share_lod(input);
}
out->set_dtype(input.dtype());
}
void IsfiniteInferMeta(const MetaTensor& x, MetaTensor* out) {
out->set_dims(x.dims());
out->set_dtype(DataType::BOOL);
......
......@@ -163,6 +163,12 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
void SizeInferMeta(const MetaTensor& input, MetaTensor* out);
void PadInferMeta(const MetaTensor& input,
const std::vector<int>& paddings,
float pad_value,
MetaTensor* out,
MetaConfig config = MetaConfig());
void DiagonalInferMeta(
const MetaTensor& input, int offset, int axis1, int axis2, MetaTensor* out);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册