Unverified commit fe1cc8bd, authored by Linjie Chen and committed by GitHub

[phi] move sigmoid_cross_entropy_with_logits, log_loss, cumsum, and auc infershape to phi (#40200)

* move infershapes to phi

* update code format

* update code format
Parent 1f857cb9
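All four operators follow the same migration pattern: the hand-written InferShape override is deleted from the fluid operator class, an equivalent InferMeta function is added under paddle/phi/infermeta, and the operator registration is wired to it through an infer-shape functor. A minimal sketch of the wiring, taken from the cumsum change below (log_loss, auc, and sigmoid_cross_entropy_with_logits repeat the same steps):

// Adapt the phi InferMeta function to the fluid InferShape interface.
DECLARE_INFER_SHAPE_FUNCTOR(cumsum, CumsumInferShapeFunctor,
                            PD_INFER_META(phi::CumsumInferMeta));
// Register the functor with the operator, replacing the removed
// CumOp::InferShape override.
REGISTER_OPERATOR(cumsum, ops::CumOp, ops::CumsumOpMaker,
                  ops::CumsumGradMaker<paddle::framework::OpDesc>,
                  ops::CumsumGradMaker<paddle::imperative::OpBase>,
                  CumsumInferShapeFunctor);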
...@@ -12,8 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
...@@ -21,17 +24,6 @@ namespace operators {
class CumOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
if (ctx->Attrs().Get<bool>("flatten")) {
ctx->SetOutputDim("Out",
phi::make_ddim({phi::product(ctx->GetInputDim("X"))}));
} else {
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
}
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class CumsumOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -87,10 +79,12 @@ class CumsumGradMaker : public framework::SingleGradOpMaker<T> {
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
DECLARE_INFER_SHAPE_FUNCTOR(cumsum, CumsumInferShapeFunctor,
PD_INFER_META(phi::CumsumInferMeta));
REGISTER_OPERATOR(cumsum, ops::CumOp, ops::CumsumOpMaker,
                  ops::CumsumGradMaker<paddle::framework::OpDesc>,
                  ops::CumsumGradMaker<paddle::imperative::OpBase>,
                  CumsumInferShapeFunctor);
REGISTER_OP_VERSION(cumsum)
    .AddCheckpoint(
...
...@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
...@@ -21,43 +24,6 @@ namespace operators {
class LogLossOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Predicted"), "Input", "Predicted", "LogLoss");
OP_INOUT_CHECK(ctx->HasInput("Labels"), "Input", "Labels", "LogLoss");
auto pred_dims = ctx->GetInputDim("Predicted");
auto label_dims = ctx->GetInputDim("Labels");
if (ctx->IsRuntime() ||
(phi::product(pred_dims) > 0 && phi::product(label_dims) > 0)) {
PADDLE_ENFORCE_EQ(
pred_dims, label_dims,
platform::errors::InvalidArgument(
"The dimensions of Input(Predicted) must be equal to the"
"dimensions of Input(Labels), but received dimensions of "
"Input(Predicted)"
"is [%s], received dimensions of Input(Labels) is [%s].",
pred_dims, label_dims));
}
PADDLE_ENFORCE_EQ(pred_dims.size(), 2,
platform::errors::InvalidArgument(
"The dimensions of Input(Predicted) must be 2,"
"But received dimensions of Input(Predicted)"
"is [%d]",
pred_dims.size()));
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(
pred_dims[1], 1,
platform::errors::InvalidArgument(
"Each row of Input(Predicted) contains a real value, "
"so the 2nd dimension of Input(X) must be 1,"
"But got [%d]",
pred_dims[1]));
}
ctx->SetOutputDim("Loss", {pred_dims[0], 1});
ctx->ShareLoD("Predicted", "Loss");
}
};
template <typename AttrType>
...@@ -145,7 +111,10 @@ class LogLossGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(log_loss, LogLossInferShapeFunctor,
PD_INFER_META(phi::LogLossInferMeta));
REGISTER_OPERATOR(log_loss, ops::LogLossOp, ops::LogLossOpMaker<float>,
                  ops::LogLossGradMaker<paddle::framework::OpDesc>,
                  ops::LogLossGradMaker<paddle::imperative::OpBase>,
                  LogLossInferShapeFunctor);
REGISTER_OPERATOR(log_loss_grad, ops::LogLossGradOp);
...@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
...@@ -21,70 +24,6 @@ class AucOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Predict"), "Input", "Predict", "Auc");
OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "Auc");
auto predict_dims = ctx->GetInputDim("Predict");
auto label_dims = ctx->GetInputDim("Label");
PADDLE_ENFORCE_GE(
predict_dims.size(), 2,
platform::errors::InvalidArgument(
"The Input(Predict) has not been initialized properly. The "
"shape of Input(Predict) = [%s], the shape size must be "
"greater_equal 2.",
predict_dims));
auto predict_width = predict_dims[1];
PADDLE_ENFORCE_NE(
phi::product(predict_dims), 0,
platform::errors::InvalidArgument(
"The Input(Predict) has not been initialized properly. The "
"shape of Input(Predict) = [%s], the shape can not involes 0.",
predict_dims));
PADDLE_ENFORCE_NE(
phi::product(label_dims), 0,
platform::errors::InvalidArgument(
"The Input(Label) has not been initialized properly. The "
"shape of Input(Label) = [%s], the shape can not involes 0.",
label_dims));
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_LE(predict_width, 2,
platform::errors::InvalidArgument(
"Only support binary classification,"
"prediction dims[1] should be 1 or 2"));
}
auto predict_height = ctx->GetInputDim("Predict")[0];
auto label_height = ctx->GetInputDim("Label")[0];
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(predict_height, label_height,
platform::errors::InvalidArgument(
"Out and Label should have same height."));
}
int num_pred_buckets = ctx->Attrs().Get<int>("num_thresholds") + 1;
int slide_steps = ctx->Attrs().Get<int>("slide_steps");
PADDLE_ENFORCE_GE(
num_pred_buckets, 1,
platform::errors::InvalidArgument("num_thresholds must larger than 1"));
PADDLE_ENFORCE_GE(slide_steps, 0,
platform::errors::InvalidArgument(
"slide_steps must be natural number"));
ctx->SetOutputDim("AUC", {1});
if (slide_steps) {
ctx->SetOutputDim("StatPosOut",
{(1 + slide_steps) * num_pred_buckets + 1});
ctx->SetOutputDim("StatNegOut",
{(1 + slide_steps) * num_pred_buckets + 1});
} else {
ctx->SetOutputDim("StatPosOut", {1, num_pred_buckets});
ctx->SetOutputDim("StatNegOut", {1, num_pred_buckets});
}
}
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
...@@ -145,4 +84,7 @@ There are two types of possible curves:
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(auc, AucInferShapeFunctor,
                            PD_INFER_META(phi::AucInferMeta));
REGISTER_OP_WITHOUT_GRADIENT(auc, ops::AucOp, ops::AucOpMaker,
                             AucInferShapeFunctor);
...@@ -15,7 +15,10 @@ limitations under the License. */
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
...@@ -26,46 +29,6 @@ const int kIgnoreIndex = -100;
class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
"SigmoidCrossEntropyWithLogitsOp");
OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label",
"SigmoidCrossEntropyWithLogitsOp");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
"SigmoidCrossEntropyWithLogitsOp");
auto x_dims = ctx->GetInputDim("X");
auto labels_dims = ctx->GetInputDim("Label");
int rank = x_dims.size();
PADDLE_ENFORCE_EQ(rank, labels_dims.size(),
platform::errors::InvalidArgument(
"Input(X) and Input(Label) shall have the same rank."
"But received: the rank of Input(X) is [%d], "
"the rank of Input(Label) is [%d].",
rank, labels_dims.size()));
bool check = true;
if ((!ctx->IsRuntime()) &&
(phi::product(x_dims) <= 0 || phi::product(labels_dims) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(
phi::slice_ddim(x_dims, 0, rank),
phi::slice_ddim(labels_dims, 0, rank),
platform::errors::InvalidArgument(
"Input(X) and Input(Label) shall have the same shape "
"except the last dimension. But received: the shape of "
"Input(X) is [%s], the shape of Input(Label) is [%s].",
x_dims, labels_dims));
}
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class SigmoidCrossEntropyWithLogitsGradOp
...@@ -201,12 +164,17 @@ DECLARE_INPLACE_OP_INFERER(SigmoidCrossEntropyWithLogitsGradInplaceInferer,
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(
sigmoid_cross_entropy_with_logits,
SigmoidCrossEntropyWithLogitsInferShapeFunctor,
PD_INFER_META(phi::SigmoidCrossEntropyWithLogitsInferMeta));
REGISTER_OPERATOR(
    sigmoid_cross_entropy_with_logits, ops::SigmoidCrossEntropyWithLogitsOp,
    ops::SigmoidCrossEntropyWithLogitsOpMaker,
    ops::SigmoidCrossEntropyWithLogitsGradOpMaker<paddle::framework::OpDesc>,
    ops::SigmoidCrossEntropyWithLogitsGradOpMaker<paddle::imperative::OpBase>,
    ops::SigmoidCrossEntropyWithLogitsInplaceInferer,
    SigmoidCrossEntropyWithLogitsInferShapeFunctor);
REGISTER_OPERATOR(sigmoid_cross_entropy_with_logits_grad,
                  ops::SigmoidCrossEntropyWithLogitsGradOp,
                  ops::SigmoidCrossEntropyWithLogitsGradInplaceInferer);
...@@ -575,6 +575,48 @@ void GatherTreeMeta(const MetaTensor& ids,
  out->set_dims(ids_dims);
}
void LogLossInferMeta(const MetaTensor& input,
const MetaTensor& label,
float epsilon,
MetaTensor* out,
MetaConfig config) {
auto pred_dims = input.dims();
auto label_dims = label.dims();
if (config.is_runtime ||
(phi::product(pred_dims) > 0 && phi::product(label_dims) > 0)) {
PADDLE_ENFORCE_EQ(
pred_dims,
label_dims,
phi::errors::InvalidArgument(
"The dimensions of Input(Predicted) must be equal to the"
"dimensions of Input(Labels), but received dimensions of "
"Input(Predicted)"
"is [%s], received dimensions of Input(Labels) is [%s].",
pred_dims,
label_dims));
}
PADDLE_ENFORCE_EQ(pred_dims.size(),
2,
phi::errors::InvalidArgument(
"The dimensions of Input(Predicted) must be 2,"
"But received dimensions of Input(Predicted)"
"is [%d]",
pred_dims.size()));
if (config.is_runtime) {
PADDLE_ENFORCE_EQ(pred_dims[1],
1,
phi::errors::InvalidArgument(
"Each row of Input(Predicted) contains a real value, "
"so the 2nd dimension of Input(X) must be 1,"
"But got [%d]",
pred_dims[1]));
}
out->set_dims({pred_dims[0], 1});
out->set_dtype(input.dtype());
out->share_lod(input);
}
void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out) {
  auto dim_x = x.dims();
  auto dim_vec = vec.dims();
...@@ -605,4 +647,45 @@ void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out) {
  out->share_lod(x);
}
void SigmoidCrossEntropyWithLogitsInferMeta(const MetaTensor& x,
const MetaTensor& label,
bool normalize,
int ignore_index,
MetaTensor* out,
MetaConfig config) {
auto x_dims = x.dims();
auto labels_dims = label.dims();
int rank = x_dims.size();
PADDLE_ENFORCE_EQ(rank,
labels_dims.size(),
phi::errors::InvalidArgument(
"Input(X) and Input(Label) shall have the same rank."
"But received: the rank of Input(X) is [%d], "
"the rank of Input(Label) is [%d].",
rank,
labels_dims.size()));
bool check = true;
if ((!config.is_runtime) &&
(phi::product(x_dims) <= 0 || phi::product(labels_dims) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(
phi::slice_ddim(x_dims, 0, rank),
phi::slice_ddim(labels_dims, 0, rank),
phi::errors::InvalidArgument(
"Input(X) and Input(Label) shall have the same shape "
"except the last dimension. But received: the shape of "
"Input(X) is [%s], the shape of Input(Label) is [%s].",
x_dims,
labels_dims));
}
out->set_dims(x_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
}
} // namespace phi
...@@ -89,6 +89,7 @@ void BincountInferMeta(const MetaTensor& x,
                       const paddle::optional<const MetaTensor&> weights,
                       int minlength,
                       MetaTensor* out);
void DistInferMeta(const MetaTensor& x,
                   const MetaTensor& y,
                   float p,
...@@ -102,6 +103,19 @@ void GatherTreeMeta(const MetaTensor& ids,
                    const MetaTensor& parents,
                    MetaTensor* out);
void LogLossInferMeta(const MetaTensor& input,
const MetaTensor& label,
float epsilon,
MetaTensor* out,
MetaConfig config = MetaConfig());
void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out);
void SigmoidCrossEntropyWithLogitsInferMeta(const MetaTensor& x,
const MetaTensor& label,
bool normalize,
int ignore_index,
MetaTensor* out,
MetaConfig config = MetaConfig());
} // namespace phi
...@@ -28,6 +28,86 @@ std::vector<DDim> GetMetaTensorsDim(const std::vector<MetaTensor*>& tensors) {
  return dims;
}
void AucInferMeta(const MetaTensor& input,
const MetaTensor& label,
const MetaTensor& stat_pos,
const MetaTensor& stat_neg,
const std::string& curve,
int num_thresholds,
int slide_steps,
MetaTensor* auc,
MetaTensor* stat_pos_out,
MetaTensor* stat_neg_out,
MetaConfig config) {
auto predict_dims = input.dims();
auto label_dims = label.dims();
PADDLE_ENFORCE_GE(
predict_dims.size(),
2,
phi::errors::InvalidArgument(
"The Input(Predict) has not been initialized properly. The "
"shape of Input(Predict) = [%s], the shape size must be "
"greater_equal 2.",
predict_dims));
auto predict_width = predict_dims[1];
PADDLE_ENFORCE_NE(
phi::product(predict_dims),
0,
phi::errors::InvalidArgument(
"The Input(Predict) has not been initialized properly. The "
"shape of Input(Predict) = [%s], the shape can not involes 0.",
predict_dims));
PADDLE_ENFORCE_NE(
phi::product(label_dims),
0,
phi::errors::InvalidArgument(
"The Input(Label) has not been initialized properly. The "
"shape of Input(Label) = [%s], the shape can not involes 0.",
label_dims));
if (config.is_runtime) {
PADDLE_ENFORCE_LE(
predict_width,
2,
phi::errors::InvalidArgument("Only support binary classification,"
"prediction dims[1] should be 1 or 2"));
}
auto predict_height = input.dims()[0];
auto label_height = label.dims()[0];
if (config.is_runtime) {
PADDLE_ENFORCE_EQ(
predict_height,
label_height,
phi::errors::InvalidArgument("Out and Label should have same height."));
}
int num_pred_buckets = num_thresholds + 1;
PADDLE_ENFORCE_GE(
num_pred_buckets,
1,
phi::errors::InvalidArgument("num_thresholds must larger than 1"));
PADDLE_ENFORCE_GE(
slide_steps,
0,
phi::errors::InvalidArgument("slide_steps must be natural number"));
auc->set_dims({1});
auc->set_dtype(DataType::INT64);
if (slide_steps) {
stat_pos_out->set_dims({(1 + slide_steps) * num_pred_buckets + 1});
stat_pos_out->set_dtype(DataType::INT64);
stat_neg_out->set_dims({(1 + slide_steps) * num_pred_buckets + 1});
stat_neg_out->set_dtype(DataType::INT64);
} else {
stat_pos_out->set_dims({1, num_pred_buckets});
stat_pos_out->set_dtype(DataType::INT64);
stat_neg_out->set_dims({1, num_pred_buckets});
stat_neg_out->set_dtype(DataType::INT64);
}
}
void AdamaxInferMeta(const MetaTensor& param,
                     const MetaTensor& grad,
                     const MetaTensor& learning_rate,
...
...@@ -20,6 +20,18 @@ namespace phi {
std::vector<DDim> GetMetaTensorsDim(const std::vector<MetaTensor*>& tensors);
void AucInferMeta(const MetaTensor& input,
const MetaTensor& label,
const MetaTensor& stat_pos,
const MetaTensor& stat_neg,
const std::string& curve,
int num_thresholds,
int slide_steps,
MetaTensor* auc,
MetaTensor* stat_pos_out,
MetaTensor* stat_neg_out,
MetaConfig config = MetaConfig());
void BilinearTensorProductInferMeta(const MetaTensor& x,
                                    const MetaTensor& y,
                                    const MetaTensor& weight,
...
...@@ -156,6 +156,24 @@ void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out) {
  out->set_layout(x.layout());
}
void CumsumInferMeta(const MetaTensor& x,
int axis,
bool flatten,
bool exclusive,
bool reverse,
MetaTensor* out) {
auto x_dims = x.dims();
if (flatten) {
out->set_dims(phi::make_ddim({phi::product(x_dims)}));
out->set_dtype(x.dtype());
} else {
out->set_dims(x_dims);
out->set_dtype(x.dtype());
}
out->share_lod(x);
}
void IncrementInferMeta(const MetaTensor& x, float value, MetaTensor* out) {
  PADDLE_ENFORCE_EQ(
      product(x.dims()),
...
...@@ -63,6 +63,13 @@ void CopyToInferMeta(const MetaTensor& x,
void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out);
void CumsumInferMeta(const MetaTensor& x,
int axis,
bool flatten,
bool exclusive,
bool reverse,
MetaTensor* out);
void IncrementInferMeta(const MetaTensor& x, float value, MetaTensor* out);
void InferMetaFromVecValue(const MetaTensor& x,
...