未验证 提交 1c205883 编写于 作者: 0 0x45f 提交者: GitHub

move eye, lerp infershape to phi (#40105)

上级 167d511f
......@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/nullary.h"
namespace paddle {
namespace operators {
......@@ -21,24 +24,6 @@ class EyeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of EyeOP should not be null."));
auto num_rows = ctx->Attrs().Get<int64_t>("num_rows");
PADDLE_ENFORCE_EQ(
num_rows >= 0, true,
platform::errors::InvalidArgument(
"The value of Input(num_rows) should be non-negative int."));
auto num_columns = ctx->Attrs().Get<int64_t>("num_columns");
if (num_columns == -1) num_columns = num_rows;
PADDLE_ENFORCE_EQ(
num_columns >= 0, true,
platform::errors::InvalidArgument(
"The value of Input(num_columns) should be non-negative int."));
ctx->SetOutputDim("Out", {num_rows, num_columns});
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
......@@ -82,8 +67,11 @@ Return an identity tensor whose shape is [num_rows, num_columns].
} // namespace paddle
namespace ops = paddle::operators;
DELCARE_INFER_SHAPE_FUNCTOR(eye, EyeInferShapeFunctor,
PT_INFER_META(phi::EyeInferMeta));
REGISTER_OPERATOR(
eye, ops::EyeOp, ops::EyeOpMaker, ops::EyeOpVarTypeInference,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
EyeInferShapeFunctor);
......@@ -12,7 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/ternary.h"
namespace paddle {
namespace operators {
......@@ -20,49 +23,6 @@ namespace operators {
class LerpOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "lerp");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "lerp");
OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "lerp");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "lerp");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto w_dims = ctx->GetInputDim("Weight");
framework::DDim out_dims;
out_dims = GetOutputDims(x_dims, y_dims);
if (w_dims.size() > 1 || w_dims[0] != 1) {
out_dims = GetOutputDims(out_dims, w_dims);
}
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
private:
framework::DDim GetOutputDims(const framework::DDim& s_dims,
const framework::DDim& l_dims) const {
if (s_dims.size() > l_dims.size()) {
return GetOutputDims(l_dims, s_dims);
}
std::vector<int64_t> shapes = phi::vectorize<int64_t>(l_dims);
for (int i = s_dims.size() - 1, j = l_dims.size() - 1; i >= 0; --i, --j) {
int64_t s = s_dims[i];
int64_t l = l_dims[j];
if (s != l) {
if (l == 1) {
shapes[j] = s;
} else if (s != 1) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The shape of tensor a %s:%d must match shape of tensor b "
"%s:%d.",
s_dims.to_str(), i, l_dims.to_str(), j));
}
}
}
return phi::make_ddim(shapes);
}
};
class LerpOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -125,10 +85,12 @@ DECLARE_INPLACE_OP_INFERER(LerpInplaceInferer, {"X", "Out"});
} // namespace operators
} // namespace paddle
DELCARE_INFER_SHAPE_FUNCTOR(lerp, LerpInferShapeFunctor,
PT_INFER_META(phi::LerpInferMeta));
REGISTER_OPERATOR(
lerp, paddle::operators::LerpOp, paddle::operators::LerpOpMaker,
paddle::operators::LerpOpGradMaker<paddle::framework::OpDesc>,
paddle::operators::LerpOpGradMaker<paddle::imperative::OpBase>,
paddle::operators::LerpInplaceInferer);
paddle::operators::LerpInplaceInferer, LerpInferShapeFunctor);
REGISTER_OPERATOR(lerp_grad, paddle::operators::LerpGradOp);
......@@ -32,4 +32,12 @@ void CreateInferMeta(const ScalarArray& shape,
CreateInferMetaBase(shape.GetData(), dtype, DataLayout::NCHW, out);
}
void EyeInferMeta(int64_t num_rows,
int64_t num_columns,
DataType dtype,
MetaTensor* out) {
if (num_columns == -1) num_columns = num_rows;
out->set_dims({num_rows, num_columns});
out->set_dtype(dtype);
}
} // namespace phi
......@@ -35,4 +35,9 @@ void CreateInferMetaBase(const std::vector<int64_t>& shape,
void CreateInferMeta(const ScalarArray& shape, DataType dtype, MetaTensor* out);
void EyeInferMeta(int64_t num_rows,
int64_t num_columns,
DataType dtype,
MetaTensor* out);
} // namespace phi
......@@ -89,4 +89,21 @@ void AddmmInferMeta(const MetaTensor& input,
out->set_dtype(input.dtype());
}
void LerpInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& weight,
MetaTensor* out) {
auto x_dims = x.dims();
auto y_dims = y.dims();
auto w_dims = weight.dims();
DDim out_dims;
out_dims = funcs::GetOutputDims(x_dims, y_dims);
if (w_dims.size() > 1 || w_dims[0] != 1) {
out_dims = funcs::GetOutputDims(out_dims, w_dims);
}
out->set_dims(out_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
}
} // namespace phi
......@@ -37,4 +37,9 @@ void AddmmInferMeta(const MetaTensor& input,
float beta,
MetaTensor* out);
void LerpInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& weight,
MetaTensor* out);
} // namespace phi
......@@ -22,7 +22,7 @@ template <typename T, typename Context>
void EyeKernel(const Context& ctx,
int64_t num_rows,
int64_t num_columns,
int dtype,
DataType dtype,
DenseTensor* out);
} // namespace phi
......@@ -140,5 +140,30 @@ inline bool CheckDims(const DDim &dims_x, const DDim &dims_y) {
return true;
}
inline DDim GetOutputDims(const DDim &s_dims, const DDim &l_dims) {
if (s_dims.size() > l_dims.size()) {
return GetOutputDims(l_dims, s_dims);
}
std::vector<int64_t> shapes = phi::vectorize<int64_t>(l_dims);
for (int i = s_dims.size() - 1, j = l_dims.size() - 1; i >= 0; --i, --j) {
int64_t s = s_dims[i];
int64_t l = l_dims[j];
if (s != l) {
if (l == 1) {
shapes[j] = s;
} else if (s != 1) {
PADDLE_THROW(errors::InvalidArgument(
"The shape of tensor a %s:%d must match shape of tensor b "
"%s:%d.",
s_dims.to_str(),
i,
l_dims.to_str(),
j));
}
}
}
return phi::make_ddim(shapes);
}
} // namespace funcs
} // namespace phi
......@@ -36,7 +36,7 @@ template <typename T, typename Context>
void EyeKernel(const Context& ctx,
int64_t num_rows,
int64_t num_columns,
int dtype,
DataType dtype,
DenseTensor* out) {
auto num = num_columns;
if (num == -1) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册