未验证 提交 1c205883 编写于 作者:0x45f 提交者:GitHub

move eye, lerp infershape to phi (#40105)

上级 167d511f
...@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/nullary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -21,24 +24,6 @@ class EyeOp : public framework::OperatorWithKernel { ...@@ -21,24 +24,6 @@ class EyeOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
// Legacy fluid shape inference for the eye op (removed by this change in
// favor of phi::EyeInferMeta): validates the output exists and the
// num_rows/num_columns attributes are non-negative, then sets Out's shape
// to [num_rows, num_columns].
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of EyeOP should not be null."));
auto num_rows = ctx->Attrs().Get<int64_t>("num_rows");
PADDLE_ENFORCE_EQ(
num_rows >= 0, true,
platform::errors::InvalidArgument(
"The value of Input(num_rows) should be non-negative int."));
auto num_columns = ctx->Attrs().Get<int64_t>("num_columns");
// num_columns == -1 is the sentinel for "square matrix": use num_rows.
if (num_columns == -1) num_columns = num_rows;
PADDLE_ENFORCE_EQ(
num_columns >= 0, true,
platform::errors::InvalidArgument(
"The value of Input(num_columns) should be non-negative int."));
ctx->SetOutputDim("Out", {num_rows, num_columns});
}
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
...@@ -82,8 +67,11 @@ Return an identity tensor whose shape is [num_rows, num_columns]. ...@@ -82,8 +67,11 @@ Return an identity tensor whose shape is [num_rows, num_columns].
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
// Registers phi::EyeInferMeta as the shape-inference functor for the eye op.
// NOTE(review): "DELCARE_INFER_SHAPE_FUNCTOR" is misspelled but matches the
// macro name declared in paddle/fluid/framework/infershape_utils.h at this
// revision — do not "fix" the spelling here without renaming the declaration.
DELCARE_INFER_SHAPE_FUNCTOR(eye, EyeInferShapeFunctor,
PT_INFER_META(phi::EyeInferMeta));
REGISTER_OPERATOR( REGISTER_OPERATOR(
eye, ops::EyeOp, ops::EyeOpMaker, ops::EyeOpVarTypeInference, eye, ops::EyeOp, ops::EyeOpMaker, ops::EyeOpVarTypeInference,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
EyeInferShapeFunctor);
...@@ -12,7 +12,10 @@ ...@@ -12,7 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/ternary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -20,49 +23,6 @@ namespace operators { ...@@ -20,49 +23,6 @@ namespace operators {
class LerpOp : public framework::OperatorWithKernel { class LerpOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
// Legacy fluid shape inference for the lerp op (removed by this change in
// favor of phi::LerpInferMeta): checks X, Y, Weight and Out exist, then
// broadcasts X and Y (and Weight when it is not a scalar-like [1] tensor)
// to produce Out's shape, sharing X's LoD with Out.
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "lerp");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "lerp");
OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "lerp");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "lerp");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto w_dims = ctx->GetInputDim("Weight");
framework::DDim out_dims;
out_dims = GetOutputDims(x_dims, y_dims);
// Skip the weight when it is a single-element 1-D tensor (scalar weight).
if (w_dims.size() > 1 || w_dims[0] != 1) {
out_dims = GetOutputDims(out_dims, w_dims);
}
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
private:
// Computes the numpy-style broadcast shape of two DDims (legacy copy,
// removed by this change in favor of phi::funcs::GetOutputDims).
// Dimensions are aligned from the trailing end; a dim of 1 broadcasts to
// the other operand's dim; any other mismatch throws InvalidArgument.
framework::DDim GetOutputDims(const framework::DDim& s_dims,
const framework::DDim& l_dims) const {
// Normalize so s_dims is the shorter (or equal) rank operand.
if (s_dims.size() > l_dims.size()) {
return GetOutputDims(l_dims, s_dims);
}
std::vector<int64_t> shapes = phi::vectorize<int64_t>(l_dims);
for (int i = s_dims.size() - 1, j = l_dims.size() - 1; i >= 0; --i, --j) {
int64_t s = s_dims[i];
int64_t l = l_dims[j];
if (s != l) {
if (l == 1) {
shapes[j] = s;
} else if (s != 1) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The shape of tensor a %s:%d must match shape of tensor b "
"%s:%d.",
s_dims.to_str(), i, l_dims.to_str(), j));
}
}
}
return phi::make_ddim(shapes);
}
}; };
class LerpOpMaker : public framework::OpProtoAndCheckerMaker { class LerpOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -125,10 +85,12 @@ DECLARE_INPLACE_OP_INFERER(LerpInplaceInferer, {"X", "Out"}); ...@@ -125,10 +85,12 @@ DECLARE_INPLACE_OP_INFERER(LerpInplaceInferer, {"X", "Out"});
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
// Registers phi::LerpInferMeta as the shape-inference functor for the lerp op.
// NOTE(review): "DELCARE_INFER_SHAPE_FUNCTOR" is misspelled but matches the
// macro name declared in paddle/fluid/framework/infershape_utils.h at this
// revision — do not rename here without also renaming the declaration.
DELCARE_INFER_SHAPE_FUNCTOR(lerp, LerpInferShapeFunctor,
PT_INFER_META(phi::LerpInferMeta));
REGISTER_OPERATOR( REGISTER_OPERATOR(
lerp, paddle::operators::LerpOp, paddle::operators::LerpOpMaker, lerp, paddle::operators::LerpOp, paddle::operators::LerpOpMaker,
paddle::operators::LerpOpGradMaker<paddle::framework::OpDesc>, paddle::operators::LerpOpGradMaker<paddle::framework::OpDesc>,
paddle::operators::LerpOpGradMaker<paddle::imperative::OpBase>, paddle::operators::LerpOpGradMaker<paddle::imperative::OpBase>,
paddle::operators::LerpInplaceInferer); paddle::operators::LerpInplaceInferer, LerpInferShapeFunctor);
REGISTER_OPERATOR(lerp_grad, paddle::operators::LerpGradOp); REGISTER_OPERATOR(lerp_grad, paddle::operators::LerpGradOp);
...@@ -32,4 +32,12 @@ void CreateInferMeta(const ScalarArray& shape, ...@@ -32,4 +32,12 @@ void CreateInferMeta(const ScalarArray& shape,
CreateInferMetaBase(shape.GetData(), dtype, DataLayout::NCHW, out); CreateInferMetaBase(shape.GetData(), dtype, DataLayout::NCHW, out);
} }
void EyeInferMeta(int64_t num_rows,
int64_t num_columns,
DataType dtype,
MetaTensor* out) {
if (num_columns == -1) num_columns = num_rows;
out->set_dims({num_rows, num_columns});
out->set_dtype(dtype);
}
} // namespace phi } // namespace phi
...@@ -35,4 +35,9 @@ void CreateInferMetaBase(const std::vector<int64_t>& shape, ...@@ -35,4 +35,9 @@ void CreateInferMetaBase(const std::vector<int64_t>& shape,
void CreateInferMeta(const ScalarArray& shape, DataType dtype, MetaTensor* out); void CreateInferMeta(const ScalarArray& shape, DataType dtype, MetaTensor* out);
// Infers the output meta for the eye op: Out has shape
// [num_rows, num_columns] (num_columns == -1 means num_columns = num_rows)
// and dtype `dtype`.
void EyeInferMeta(int64_t num_rows,
int64_t num_columns,
DataType dtype,
MetaTensor* out);
} // namespace phi } // namespace phi
...@@ -89,4 +89,21 @@ void AddmmInferMeta(const MetaTensor& input, ...@@ -89,4 +89,21 @@ void AddmmInferMeta(const MetaTensor& input,
out->set_dtype(input.dtype()); out->set_dtype(input.dtype());
} }
// Infers the output meta for the lerp op: Out's shape is the broadcast of
// x and y (and of weight when it is not a scalar-like [1] tensor); Out
// inherits x's dtype and LoD.
void LerpInferMeta(const MetaTensor& x,
                   const MetaTensor& y,
                   const MetaTensor& weight,
                   MetaTensor* out) {
  auto x_dims = x.dims();
  auto y_dims = y.dims();
  auto w_dims = weight.dims();
  DDim out_dims = funcs::GetOutputDims(x_dims, y_dims);
  // Guard the w_dims[0] access: the original condition read w_dims[0] even
  // for a rank-0 dims, which is out of bounds. A rank-0 weight broadcasts
  // to anything, so skipping it leaves out_dims unchanged.
  if (w_dims.size() > 1 || (w_dims.size() == 1 && w_dims[0] != 1)) {
    out_dims = funcs::GetOutputDims(out_dims, w_dims);
  }
  out->set_dims(out_dims);
  out->set_dtype(x.dtype());
  out->share_lod(x);
}
} // namespace phi } // namespace phi
...@@ -37,4 +37,9 @@ void AddmmInferMeta(const MetaTensor& input, ...@@ -37,4 +37,9 @@ void AddmmInferMeta(const MetaTensor& input,
float beta, float beta,
MetaTensor* out); MetaTensor* out);
// Infers the output meta for the lerp op: Out's shape is the broadcast of
// x, y and (when not scalar-like) weight; Out inherits x's dtype and LoD.
void LerpInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& weight,
MetaTensor* out);
} // namespace phi } // namespace phi
...@@ -22,7 +22,7 @@ template <typename T, typename Context> ...@@ -22,7 +22,7 @@ template <typename T, typename Context>
void EyeKernel(const Context& ctx, void EyeKernel(const Context& ctx,
int64_t num_rows, int64_t num_rows,
int64_t num_columns, int64_t num_columns,
int dtype, DataType dtype,
DenseTensor* out); DenseTensor* out);
} // namespace phi } // namespace phi
...@@ -140,5 +140,30 @@ inline bool CheckDims(const DDim &dims_x, const DDim &dims_y) { ...@@ -140,5 +140,30 @@ inline bool CheckDims(const DDim &dims_x, const DDim &dims_y) {
return true; return true;
} }
// Computes the numpy-style broadcast shape of two DDims. Dimensions are
// aligned from the trailing end; an extent of 1 broadcasts to the other
// operand's extent; any other mismatch throws InvalidArgument.
inline DDim GetOutputDims(const DDim &s_dims, const DDim &l_dims) {
  // Normalize so the first operand is the lower-rank (or equal-rank) shape.
  if (s_dims.size() > l_dims.size()) {
    return GetOutputDims(l_dims, s_dims);
  }
  // Start from the higher-rank shape and patch in broadcast extents.
  std::vector<int64_t> out_shape = phi::vectorize<int64_t>(l_dims);
  const int offset = l_dims.size() - s_dims.size();
  for (int i = s_dims.size() - 1; i >= 0; --i) {
    const int j = i + offset;
    const int64_t s = s_dims[i];
    const int64_t l = l_dims[j];
    if (s == l) {
      continue;
    }
    if (l == 1) {
      out_shape[j] = s;
    } else if (s != 1) {
      PADDLE_THROW(errors::InvalidArgument(
          "The shape of tensor a %s:%d must match shape of tensor b "
          "%s:%d.",
          s_dims.to_str(),
          i,
          l_dims.to_str(),
          j));
    }
  }
  return phi::make_ddim(out_shape);
}
} // namespace funcs } // namespace funcs
} // namespace phi } // namespace phi
...@@ -36,7 +36,7 @@ template <typename T, typename Context> ...@@ -36,7 +36,7 @@ template <typename T, typename Context>
void EyeKernel(const Context& ctx, void EyeKernel(const Context& ctx,
int64_t num_rows, int64_t num_rows,
int64_t num_columns, int64_t num_columns,
int dtype, DataType dtype,
DenseTensor* out) { DenseTensor* out) {
auto num = num_columns; auto num = num_columns;
if (num == -1) { if (num == -1) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册