未验证 提交 a0cb3203 编写于 作者: 0 0x45f 提交者: GitHub

[Phi]Move size, erfinv, pixel_shuffle infershape to phi (#39949)

* move size, erfinv, pixel_shuffle infershape to phi

* fix erfinv infermeta
上级 2c66775b
...@@ -12,7 +12,10 @@ ...@@ -12,7 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -20,14 +23,6 @@ namespace operators { ...@@ -20,14 +23,6 @@ namespace operators {
class ErfinvOp : public framework::OperatorWithKernel { class ErfinvOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "erfinv");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "erfinv");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
}; };
class ErfinvOpMaker : public framework::OpProtoAndCheckerMaker { class ErfinvOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -78,10 +73,13 @@ DECLARE_INPLACE_OP_INFERER(ErfinvInplaceInferer, {"X", "Out"}); ...@@ -78,10 +73,13 @@ DECLARE_INPLACE_OP_INFERER(ErfinvInplaceInferer, {"X", "Out"});
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
DELCARE_INFER_SHAPE_FUNCTOR(erfinv, ErfinvInferShapeFunctor,
PT_INFER_META(phi::UnchangedInferMeta));
REGISTER_OPERATOR( REGISTER_OPERATOR(
erfinv, paddle::operators::ErfinvOp, paddle::operators::ErfinvOpMaker, erfinv, paddle::operators::ErfinvOp, paddle::operators::ErfinvOpMaker,
paddle::operators::ErfinvGradMaker<paddle::framework::OpDesc>, paddle::operators::ErfinvGradMaker<paddle::framework::OpDesc>,
paddle::operators::ErfinvGradMaker<paddle::imperative::OpBase>, paddle::operators::ErfinvGradMaker<paddle::imperative::OpBase>,
paddle::operators::ErfinvInplaceInferer); paddle::operators::ErfinvInplaceInferer, ErfinvInferShapeFunctor);
REGISTER_OPERATOR(erfinv_grad, paddle::operators::ErfinvGradOp); REGISTER_OPERATOR(erfinv_grad, paddle::operators::ErfinvGradOp);
...@@ -10,8 +10,11 @@ See the License for the specific language governing permissions and ...@@ -10,8 +10,11 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <memory> #include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -19,56 +22,6 @@ namespace operators { ...@@ -19,56 +22,6 @@ namespace operators {
class PixelShuffleOp : public framework::OperatorWithKernel { class PixelShuffleOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::NotFound(
"Input(X) of PixelShuffleOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of PixelShuffleOp should not be null."));
auto input_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(input_dims.size(), 4,
platform::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W] "
"or [N, H, W, C], but got %u.",
input_dims.size()));
auto upscale_factor = ctx->Attrs().Get<int>("upscale_factor");
const std::string data_format =
ctx->Attrs().Get<std::string>("data_format");
const bool channel_last = (data_format == "NHWC");
if (!channel_last) {
PADDLE_ENFORCE_EQ(
input_dims[1] % (upscale_factor * upscale_factor), 0,
platform::errors::InvalidArgument(
"The square of upscale_factor[%u] should divide the "
"number of channel[%u]",
upscale_factor * upscale_factor, input_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
input_dims[3] % (upscale_factor * upscale_factor), 0,
platform::errors::InvalidArgument(
"The square of upscale_factor[%u] should divide the "
"number of channel[%u]",
upscale_factor * upscale_factor, input_dims[3]));
}
auto output_dims = input_dims;
output_dims[0] = input_dims[0];
if (!channel_last) {
output_dims[1] = input_dims[1] / (upscale_factor * upscale_factor);
output_dims[2] = input_dims[2] * upscale_factor;
output_dims[3] = input_dims[3] * upscale_factor;
} else {
output_dims[1] = input_dims[1] * upscale_factor;
output_dims[2] = input_dims[2] * upscale_factor;
output_dims[3] = input_dims[3] / (upscale_factor * upscale_factor);
}
ctx->SetOutputDim("Out", output_dims);
}
}; };
class PixelShuffleOpMaker : public framework::OpProtoAndCheckerMaker { class PixelShuffleOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -171,9 +124,13 @@ class PixelShuffleGradOp : public framework::OperatorWithKernel { ...@@ -171,9 +124,13 @@ class PixelShuffleGradOp : public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
DELCARE_INFER_SHAPE_FUNCTOR(pixel_shuffle, PixelShuffleInferShapeFunctor,
PT_INFER_META(phi::PixelShuffleInferMeta));
REGISTER_OPERATOR(pixel_shuffle, ops::PixelShuffleOp, ops::PixelShuffleOpMaker, REGISTER_OPERATOR(pixel_shuffle, ops::PixelShuffleOp, ops::PixelShuffleOpMaker,
ops::PixelShuffleGradMaker<paddle::framework::OpDesc>, ops::PixelShuffleGradMaker<paddle::framework::OpDesc>,
ops::PixelShuffleGradMaker<paddle::imperative::OpBase>); ops::PixelShuffleGradMaker<paddle::imperative::OpBase>,
PixelShuffleInferShapeFunctor);
REGISTER_OPERATOR(pixel_shuffle_grad, ops::PixelShuffleGradOp); REGISTER_OPERATOR(pixel_shuffle_grad, ops::PixelShuffleGradOp);
......
...@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -20,13 +23,6 @@ namespace operators { ...@@ -20,13 +23,6 @@ namespace operators {
class SizeOp : public framework::OperatorWithKernel { class SizeOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Size");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Size");
ctx->SetOutputDim("Out", {1});
}
}; };
class SizeOpMaker : public framework::OpProtoAndCheckerMaker { class SizeOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -48,7 +44,10 @@ Return the number of elements in the input. ...@@ -48,7 +44,10 @@ Return the number of elements in the input.
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
DELCARE_INFER_SHAPE_FUNCTOR(size, SizeInferShapeFunctor,
PT_INFER_META(phi::SizeInferMeta));
REGISTER_OPERATOR( REGISTER_OPERATOR(
size, ops::SizeOp, ops::SizeOpMaker, size, ops::SizeOp, ops::SizeOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
SizeInferShapeFunctor);
...@@ -856,6 +856,57 @@ void DiagInferMeta(const MetaTensor& x, ...@@ -856,6 +856,57 @@ void DiagInferMeta(const MetaTensor& x,
} }
} }
void SizeInferMeta(const MetaTensor& input, MetaTensor* out) {
out->set_dtype(DataType::INT64);
out->set_dims({1});
}
void PixelShuffleInferMeta(const MetaTensor& x,
int upscale_factor,
const std::string& data_format,
MetaTensor* out) {
auto input_dims = x.dims();
PADDLE_ENFORCE_EQ(input_dims.size(),
4,
phi::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W] "
"or [N, H, W, C], but got %u.",
input_dims.size()));
const bool channel_last = (data_format == "NHWC");
if (!channel_last) {
PADDLE_ENFORCE_EQ(input_dims[1] % (upscale_factor * upscale_factor),
0,
phi::errors::InvalidArgument(
"The square of upscale_factor[%u] should divide the "
"number of channel[%u]",
upscale_factor * upscale_factor,
input_dims[1]));
} else {
PADDLE_ENFORCE_EQ(input_dims[3] % (upscale_factor * upscale_factor),
0,
phi::errors::InvalidArgument(
"The square of upscale_factor[%u] should divide the "
"number of channel[%u]",
upscale_factor * upscale_factor,
input_dims[3]));
}
auto output_dims = input_dims;
output_dims[0] = input_dims[0];
if (!channel_last) {
output_dims[1] = input_dims[1] / (upscale_factor * upscale_factor);
output_dims[2] = input_dims[2] * upscale_factor;
output_dims[3] = input_dims[3] * upscale_factor;
} else {
output_dims[1] = input_dims[1] * upscale_factor;
output_dims[2] = input_dims[2] * upscale_factor;
output_dims[3] = input_dims[3] / (upscale_factor * upscale_factor);
}
out->set_dtype(x.dtype());
out->set_dims(output_dims);
}
} // namespace phi } // namespace phi
PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta); PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta);
......
...@@ -129,4 +129,11 @@ void DiagInferMeta(const MetaTensor& x, ...@@ -129,4 +129,11 @@ void DiagInferMeta(const MetaTensor& x,
float padding_value, float padding_value,
MetaTensor* out); MetaTensor* out);
void SizeInferMeta(const MetaTensor& input, MetaTensor* out);
void PixelShuffleInferMeta(const MetaTensor& x,
int upscale_factor,
const std::string& data_format,
MetaTensor* out);
} // namespace phi } // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册