Unverified commit a40ea45e, authored by xiongkun, committed by GitHub

[phi] transfer the infer shape of accuracy op into phi (#40358)

* transfer the infershape of accuracy op into phi

* add set_dtype

* add setdtype
Parent 431afc39
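In outline, the change follows the phi migration pattern visible in the hunks below: the hand-written AccuracyOp::InferShape override is removed, an AccuracyInferMeta function that works on MetaTensor objects is added to the phi infermeta ternary header/source, and the operator registration binds that function through an infer-shape functor. A condensed sketch of the glue on the operator side, pieced together from the diff below rather than a complete source file:

// Condensed from the accuracy_op.cc hunks in this commit; include paths and
// macro names are the ones appearing in the diff, everything else is omitted.
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/ternary.h"

// Wrap phi::AccuracyInferMeta so the fluid operator can use it in place of
// its old InferShape override.
DECLARE_INFER_SHAPE_FUNCTOR(accuracy, AccuracyInferShapeFunctor,
                            PD_INFER_META(phi::AccuracyInferMeta));

namespace ops = paddle::operators;
REGISTER_OPERATOR(
    accuracy, ops::AccuracyOp, ops::AccuracyOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
    AccuracyInferShapeFunctor);  // the functor supplies shape/meta inference

Because AccuracyInferMeta also calls set_dtype() on the Accuracy, Correct, and Total outputs (see the ternary.cc hunk), dtype propagation now lives in phi as well, which is what the "add set_dtype" entries in the commit message refer to.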
@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

+#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/infermeta/ternary.h"

namespace paddle {
namespace operators {

@@ -21,69 +23,6 @@ class AccuracyOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE_EQ(
-        ctx->HasInput("Out"), true,
-        platform::errors::NotFound("Input (Out) of AccuracyOp is not found."));
-    PADDLE_ENFORCE_EQ(ctx->HasInput("Indices"), true,
-                      platform::errors::NotFound(
-                          "Input (Indices) of AccuracyOp is not found."));
-    PADDLE_ENFORCE_EQ(ctx->HasInput("Label"), true,
-                      platform::errors::NotFound(
-                          "Input (Label) of AccuracyOp is not found."));
-    PADDLE_ENFORCE_EQ(ctx->HasOutput("Accuracy"), true,
-                      platform::errors::NotFound(
-                          "Output (Accuracy) of AccuracyOp is not found."));
-    PADDLE_ENFORCE_EQ(ctx->HasOutput("Correct"), true,
-                      platform::errors::NotFound(
-                          "Output (Correct) of AccuracyOp is not found."));
-    PADDLE_ENFORCE_EQ(ctx->HasOutput("Total"), true,
-                      platform::errors::NotFound(
-                          "Output (Total) of AccuracyOp is not found."));
-    OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "Accuracy");
-    OP_INOUT_CHECK(ctx->HasInput("Indices"), "Input", "Indices", "Accuracy");
-    OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "Accuracy");
-    OP_INOUT_CHECK(ctx->HasOutput("Accuracy"), "Output", "Accuracy",
-                   "Accuracy");
-    OP_INOUT_CHECK(ctx->HasOutput("Correct"), "Output", "Correct", "Accuracy");
-    OP_INOUT_CHECK(ctx->HasOutput("Total"), "Output", "Total", "Accuracy");
-    auto inference_dim = ctx->GetInputDim("Out");
-    auto label_dim = ctx->GetInputDim("Label");
-    // Assume indices has same shape as inference, because
-    // it's the output of topk.
-    PADDLE_ENFORCE_EQ(
-        label_dim.size(), 2,
-        platform::errors::InvalidArgument(
-            "ShapeError: label's dimensions of AccuracyOp must be 2. "
-            "But received label's dimensions = %d, label's shape = [%s]",
-            label_dim.size(), label_dim));
-    if (ctx->IsRuntime()) {
-      PADDLE_ENFORCE_EQ(label_dim[1], 1,
-                        platform::errors::InvalidArgument(
-                            "ShapeError: label's second dimension of "
-                            "AccuracyOp must be 1. But received label's "
-                            "second dimension is = %d, label's shape = [%s]",
-                            label_dim[1], label_dim));
-      PADDLE_ENFORCE_EQ(
-          inference_dim[0], label_dim[0],
-          platform::errors::InvalidArgument(
-              "ShapeError: the output's num_rows of AccuracyOp must be"
-              " the same as label's num_rows. But received output's "
-              "shape = [%s], label's shape = [%s], output's num_rows = %d, "
-              "label's "
-              "num_rows = %d",
-              inference_dim, label_dim, inference_dim[0], label_dim[0]));
-    }
-    ctx->SetOutputDim("Accuracy", {1});
-    ctx->SetOutputDim("Correct", {1});
-    ctx->SetOutputDim("Total", {1});
-    ctx->ShareLoD("Out", /*->*/ "Accuracy");
-  }
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {

@@ -125,8 +64,11 @@ with the input Out(Inference).
// FIXME(typhoonzero): types of T is for infernece data.
// label data is always int.
+DECLARE_INFER_SHAPE_FUNCTOR(accuracy, AccuracyInferShapeFunctor,
+                            PD_INFER_META(phi::AccuracyInferMeta));
namespace ops = paddle::operators;
REGISTER_OPERATOR(
    accuracy, ops::AccuracyOp, ops::AccuracyOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    AccuracyInferShapeFunctor);
@@ -285,6 +285,58 @@ void LinspaceInferMeta(const MetaTensor& start,
  out->set_dtype(start.dtype());
}
+void AccuracyInferMeta(const MetaTensor& out,
+                       const MetaTensor& indice,
+                       const MetaTensor& label,
+                       MetaTensor* accuracy,
+                       MetaTensor* correct,
+                       MetaTensor* total,
+                       MetaConfig config) {
+  auto inference_dim = out.dims();
+  auto label_dim = label.dims();
+  // Assume indices has same shape as inference, because
+  // it's the output of topk.
+  PADDLE_ENFORCE_EQ(
+      label_dim.size(),
+      2,
+      phi::errors::InvalidArgument(
+          "ShapeError: label's dimensions of AccuracyOp must be 2. "
+          "But received label's dimensions = %d, label's shape = [%s]",
+          label_dim.size(),
+          label_dim));
+  if (config.is_runtime) {
+    PADDLE_ENFORCE_EQ(label_dim[1],
+                      1,
+                      phi::errors::InvalidArgument(
+                          "ShapeError: label's second dimension of "
+                          "AccuracyOp must be 1. But received label's "
+                          "second dimension is = %d, label's shape = [%s]",
+                          label_dim[1],
+                          label_dim));
+    PADDLE_ENFORCE_EQ(
+        inference_dim[0],
+        label_dim[0],
+        phi::errors::InvalidArgument(
+            "ShapeError: the output's num_rows of AccuracyOp must be"
+            " the same as label's num_rows. But received output's "
+            "shape = [%s], label's shape = [%s], output's num_rows = %d, "
+            "label's "
+            "num_rows = %d",
+            inference_dim,
+            label_dim,
+            inference_dim[0],
+            label_dim[0]));
+  }
+  accuracy->set_dims({1});
+  accuracy->set_dtype(out.dtype());
+  correct->set_dims({1});
+  correct->set_dtype(out.dtype());
+  total->set_dims({1});
+  total->set_dtype(out.dtype());
+  accuracy->share_lod(out);
+}
void GraphSendRecvInferMeta(const MetaTensor& x,
                            const MetaTensor& src_index,
                            const MetaTensor& dst_index,
@@ -29,6 +29,14 @@ namespace phi {
// NOTE: The name "InferShape" may be not appropriate. "InferMeta" may be good.
// Because functions in this file not only can infer shape, but also need
// infer lod or other useful data.
+//
+void AccuracyInferMeta(const MetaTensor& out,
+                       const MetaTensor& indice,
+                       const MetaTensor& label,
+                       MetaTensor* accuracy,
+                       MetaTensor* correct,
+                       MetaTensor* total,
+                       MetaConfig config = MetaConfig());

void AddmmInferMeta(const MetaTensor& input,
                    const MetaTensor& x,