Unverified · Commit 6d205516, authored by xiongkun, committed by GitHub

transfer cumprod and kldiv_loss infershape to phi (#40575)

Parent c7637700
@@ -12,8 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
+#include "paddle/phi/infermeta/unary.h"
 
 namespace paddle {
 namespace operators {
@@ -21,14 +23,6 @@ namespace operators {
 class CumprodOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Cumprod");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Cumprod");
-    ctx->ShareDim("X", "Out");
-    ctx->ShareLoD("X", "Out");
-  }
 };
 
 class CumprodOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -82,9 +76,12 @@ class CumprodGradOp : public framework::OperatorWithKernel {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+
+DECLARE_INFER_SHAPE_FUNCTOR(cumprod, CumprodInferShapeFunctor,
+                            PD_INFER_META(phi::UnchangedInferMeta));
 REGISTER_OPERATOR(cumprod, ops::CumprodOp, ops::CumprodOpMaker,
                   ops::CumprodGradOpMaker<paddle::framework::OpDesc>,
-                  ops::CumprodGradOpMaker<paddle::imperative::OpBase>);
+                  ops::CumprodGradOpMaker<paddle::imperative::OpBase>,
+                  CumprodInferShapeFunctor);
 
 REGISTER_OPERATOR(cumprod_grad, ops::CumprodGradOp);
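The handwritten CumprodOp::InferShape removed above only shared the input's shape and LoD with the output, so it collapses into the reusable phi::UnchangedInferMeta, bound to the op at registration through DECLARE_INFER_SHAPE_FUNCTOR / PD_INFER_META. Below is a minimal, self-contained sketch of that functor pattern; MetaTensor, OpInfo, and infer_meta here are hypothetical stand-ins, not Paddle's real types:

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-in for phi::MetaTensor: shape metadata only.
struct MetaTensor {
  std::vector<int64_t> dims;
};

// Analogue of phi::UnchangedInferMeta: the output mirrors the input.
void UnchangedInferMeta(const MetaTensor& x, MetaTensor* out) {
  out->dims = x.dims;
}

// Stand-in for DECLARE_INFER_SHAPE_FUNCTOR + REGISTER_OPERATOR: the op
// record carries its shape function instead of overriding InferShape.
struct OpInfo {
  std::string type;
  std::function<void(const MetaTensor&, MetaTensor*)> infer_meta;
};

int main() {
  OpInfo cumprod{"cumprod", UnchangedInferMeta};
  MetaTensor x{{2, 3, 4}}, out;
  cumprod.infer_meta(x, &out);
  assert(out.dims == x.dims);  // cumprod keeps the input shape
  return 0;
}
```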
@@ -11,7 +11,9 @@
 #include <memory>
 #include <string>
 
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/infermeta/binary.h"
 
 namespace paddle {
 namespace operators {
@@ -21,44 +23,6 @@ using framework::Tensor;
 class KLDivLossOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "KLDivLoss");
-    OP_INOUT_CHECK(ctx->HasInput("Target"), "Input", "Target", "KLDivLoss");
-    OP_INOUT_CHECK(ctx->HasOutput("Loss"), "Output", "Loss", "KLDivLoss");
-
-    auto dim_x = ctx->GetInputDim("X");
-    auto dim_target = ctx->GetInputDim("Target");
-    PADDLE_ENFORCE_EQ(dim_x.size(), dim_target.size(),
-                      platform::errors::InvalidArgument(
-                          "Input(X) rank and Input(Target) rank should be "
-                          "same, but received X rank(%d) != Target rank(%d)",
-                          dim_x.size(), dim_target.size()));
-    for (int i = 0; i < dim_x.size(); i++) {
-      if (ctx->IsRuntime() || (dim_x[i] > 0 && dim_target[i] > 0)) {
-        PADDLE_ENFORCE_EQ(
-            dim_x[i], dim_target[i],
-            platform::errors::InvalidArgument(
-                "Input(X) and Input(Target) should in same shape. but received "
-                "X dimension[%d](%d) != Target dimension[%d](%d)",
-                i, dim_x[i], i, dim_target[i]));
-      }
-    }
-
-    auto reduction = ctx->Attrs().Get<std::string>("reduction");
-    auto reduction_valid = "mean" == reduction || "sum" == reduction ||
-                           "batchmean" == reduction || "none" == reduction;
-    PADDLE_ENFORCE_EQ(
-        reduction_valid, true,
-        platform::errors::InvalidArgument(
-            "Attr(reduction) can only be 'none'|'batchmean'|'sum'|'mean'."));
-
-    if ("none" == reduction) {
-      ctx->SetOutputDim("Loss", dim_x);
-    } else {
-      ctx->SetOutputDim("Loss", {1});
-    }
-  }
-
  protected:
   framework::OpKernelType GetExpectedKernelType(
@@ -171,8 +135,12 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(KLDivLossGradNoNeedBufferVarInferer, "X");
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+
+DECLARE_INFER_SHAPE_FUNCTOR(kldiv_loss, KLDivInferShapeFunctor,
+                            PD_INFER_META(phi::KLDivInferMeta));
 REGISTER_OPERATOR(kldiv_loss, ops::KLDivLossOp, ops::KLDivLossOpMaker,
                   ops::KLDivLossOpGradMaker<paddle::framework::OpDesc>,
-                  ops::KLDivLossOpGradMaker<paddle::imperative::OpBase>);
+                  ops::KLDivLossOpGradMaker<paddle::imperative::OpBase>,
+                  KLDivInferShapeFunctor);
 
 REGISTER_OPERATOR(kldiv_loss_grad, ops::KLDivLossOpGrad,
                   ops::KLDivLossGradNoNeedBufferVarInferer);
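The deleted KLDivLossOp::InferShape body moves essentially verbatim into phi::KLDivInferMeta (next hunk), with ctx->IsRuntime() becoming config.is_runtime. The subtle part is the guard on the per-dimension check: a dimension is only compared when both sides are known (> 0) or when shapes are final at runtime, because at compile time a dynamic dimension may still be -1. A hedged sketch of just that guard, using a hypothetical DimsCompatible helper rather than Paddle's API:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

bool DimsCompatible(const std::vector<int64_t>& x,
                    const std::vector<int64_t>& target, bool is_runtime) {
  if (x.size() != target.size()) return false;  // ranks must match
  for (std::size_t i = 0; i < x.size(); ++i) {
    // A compile-time dim can be unknown (<= 0, e.g. -1); only enforce
    // equality when both sides are known, or once shapes are final.
    if (is_runtime || (x[i] > 0 && target[i] > 0)) {
      if (x[i] != target[i]) return false;
    }
  }
  return true;
}

int main() {
  // -1 marks an unknown (dynamic) dim at compile time: not checked yet.
  assert(DimsCompatible({-1, 10}, {8, 10}, /*is_runtime=*/false));
  assert(!DimsCompatible({8, 10}, {8, 12}, /*is_runtime=*/true));
  return 0;
}
```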
@@ -73,6 +73,51 @@ void AllValueCompareInferMeta(const MetaTensor& x,
   out->set_dtype(DataType::BOOL);
 }
 
+void KLDivInferMeta(const MetaTensor& x,
+                    const MetaTensor& label,
+                    const std::string& reduction,
+                    MetaTensor* out,
+                    MetaConfig config) {
+  auto dim_x = x.dims();
+  auto dim_target = label.dims();
+  PADDLE_ENFORCE_EQ(dim_x.size(),
+                    dim_target.size(),
+                    phi::errors::InvalidArgument(
+                        "Input(X) rank and Input(Target) rank should be "
+                        "same, but received X rank(%d) != Target rank(%d)",
+                        dim_x.size(),
+                        dim_target.size()));
+
+  for (int i = 0; i < dim_x.size(); i++) {
+    if (config.is_runtime || (dim_x[i] > 0 && dim_target[i] > 0)) {
+      PADDLE_ENFORCE_EQ(
+          dim_x[i],
+          dim_target[i],
+          phi::errors::InvalidArgument(
+              "Input(X) and Input(Target) should in same shape. but received "
+              "X dimension[%d](%d) != Target dimension[%d](%d)",
+              i,
+              dim_x[i],
+              i,
+              dim_target[i]));
+    }
+  }
+
+  auto reduction_valid = "mean" == reduction || "sum" == reduction ||
+                         "batchmean" == reduction || "none" == reduction;
+  PADDLE_ENFORCE_EQ(
+      reduction_valid,
+      true,
+      phi::errors::InvalidArgument(
+          "Attr(reduction) can only be 'none'|'batchmean'|'sum'|'mean'."));
+
+  if ("none" == reduction) {
+    out->set_dims(dim_x);
+  } else {
+    out->set_dims({1});
+  }
+  out->set_dtype(x.dtype());
+}
+
 void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
   out->share_meta(x);
 }
...
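The output-shape rule implemented above is simple: with reduction == "none" the KL-divergence loss is elementwise, so Loss keeps X's shape; for "mean", "sum", or "batchmean" it reduces to a single-element output {1}. A tiny standalone illustration, using a hypothetical KLDivLossShape helper:

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// reduction == "none": elementwise loss, same shape as X;
// "mean" / "sum" / "batchmean": a single-element output {1}.
std::vector<int64_t> KLDivLossShape(const std::vector<int64_t>& dim_x,
                                    const std::string& reduction) {
  return reduction == "none" ? dim_x : std::vector<int64_t>{1};
}

int main() {
  assert((KLDivLossShape({8, 10}, "none") == std::vector<int64_t>{8, 10}));
  assert((KLDivLossShape({8, 10}, "batchmean") == std::vector<int64_t>{1}));
  return 0;
}
```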
@@ -35,6 +35,12 @@ void AllValueCompareInferMeta(const MetaTensor& x,
                               MetaTensor* out,
                               MetaConfig config = MetaConfig());
 
+void KLDivInferMeta(const MetaTensor& x,
+                    const MetaTensor& label,
+                    const std::string& reduction,
+                    MetaTensor* out,
+                    MetaConfig config = MetaConfig());
+
 void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);
 
 void BCELossInferMeta(const MetaTensor& input,
...