未验证 提交 a40ea45e 编写于 作者: X xiongkun 提交者: GitHub

[phi] transfer the infer shape of accuracy op into phi (#40358)

* transfer the infershape of accuracy op into phi

* add set_dtype

* add setdtype
上级 431afc39
......@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/ternary.h"
namespace paddle {
namespace operators {
......@@ -21,69 +23,6 @@ class AccuracyOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("Out"), true,
platform::errors::NotFound("Input (Out) of AccuracyOp is not found."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Indices"), true,
platform::errors::NotFound(
"Input (Indices) of AccuracyOp is not found."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Label"), true,
platform::errors::NotFound(
"Input (Label) of AccuracyOp is not found."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Accuracy"), true,
platform::errors::NotFound(
"Output (Accuracy) of AccuracyOp is not found."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Correct"), true,
platform::errors::NotFound(
"Output (Correct) of AccuracyOp is not found."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Total"), true,
platform::errors::NotFound(
"Output (Total) of AccuracyOp is not found."));
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "Accuracy");
OP_INOUT_CHECK(ctx->HasInput("Indices"), "Input", "Indices", "Accuracy");
OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "Accuracy");
OP_INOUT_CHECK(ctx->HasOutput("Accuracy"), "Output", "Accuracy",
"Accuracy");
OP_INOUT_CHECK(ctx->HasOutput("Correct"), "Output", "Correct", "Accuracy");
OP_INOUT_CHECK(ctx->HasOutput("Total"), "Output", "Total", "Accuracy");
auto inference_dim = ctx->GetInputDim("Out");
auto label_dim = ctx->GetInputDim("Label");
// Assume indices has same shape as inference, because
// it's the output of topk.
PADDLE_ENFORCE_EQ(
label_dim.size(), 2,
platform::errors::InvalidArgument(
"ShapeError: label's dimensions of AccuracyOp must be 2. "
"But received label's dimensions = %d, label's shape = [%s]",
label_dim.size(), label_dim));
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(label_dim[1], 1,
platform::errors::InvalidArgument(
"ShapeError: label's second dimension of "
"AccuracyOp must be 1. But received label's "
"second dimension is = %d, label's shape = [%s]",
label_dim[1], label_dim));
PADDLE_ENFORCE_EQ(
inference_dim[0], label_dim[0],
platform::errors::InvalidArgument(
"ShapeError: the output's num_rows of AccuracyOp must be"
" the same as label's num_rows. But received output's "
"shape = [%s], label's shape = [%s], output's num_rows = %d, "
"label's "
"num_rows = %d",
inference_dim, label_dim, inference_dim[0], label_dim[0]));
}
ctx->SetOutputDim("Accuracy", {1});
ctx->SetOutputDim("Correct", {1});
ctx->SetOutputDim("Total", {1});
ctx->ShareLoD("Out", /*->*/ "Accuracy");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
......@@ -125,8 +64,11 @@ with the input Out(Inference).
// FIXME(typhoonzero): types of T is for infernece data.
// label data is always int.
DECLARE_INFER_SHAPE_FUNCTOR(accuracy, AccuracyInferShapeFunctor,
PD_INFER_META(phi::AccuracyInferMeta));
namespace ops = paddle::operators;
REGISTER_OPERATOR(
accuracy, ops::AccuracyOp, ops::AccuracyOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
AccuracyInferShapeFunctor);
......@@ -285,6 +285,58 @@ void LinspaceInferMeta(const MetaTensor& start,
out->set_dtype(start.dtype());
}
void AccuracyInferMeta(const MetaTensor& out,
const MetaTensor& indice,
const MetaTensor& label,
MetaTensor* accuracy,
MetaTensor* correct,
MetaTensor* total,
MetaConfig config) {
auto inference_dim = out.dims();
auto label_dim = label.dims();
// Assume indices has same shape as inference, because
// it's the output of topk.
PADDLE_ENFORCE_EQ(
label_dim.size(),
2,
phi::errors::InvalidArgument(
"ShapeError: label's dimensions of AccuracyOp must be 2. "
"But received label's dimensions = %d, label's shape = [%s]",
label_dim.size(),
label_dim));
if (config.is_runtime) {
PADDLE_ENFORCE_EQ(label_dim[1],
1,
phi::errors::InvalidArgument(
"ShapeError: label's second dimension of "
"AccuracyOp must be 1. But received label's "
"second dimension is = %d, label's shape = [%s]",
label_dim[1],
label_dim));
PADDLE_ENFORCE_EQ(
inference_dim[0],
label_dim[0],
phi::errors::InvalidArgument(
"ShapeError: the output's num_rows of AccuracyOp must be"
" the same as label's num_rows. But received output's "
"shape = [%s], label's shape = [%s], output's num_rows = %d, "
"label's "
"num_rows = %d",
inference_dim,
label_dim,
inference_dim[0],
label_dim[0]));
}
accuracy->set_dims({1});
accuracy->set_dtype(out.dtype());
correct->set_dims({1});
correct->set_dtype(out.dtype());
total->set_dims({1});
total->set_dtype(out.dtype());
accuracy->share_lod(out);
}
void GraphSendRecvInferMeta(const MetaTensor& x,
const MetaTensor& src_index,
const MetaTensor& dst_index,
......
......@@ -29,6 +29,14 @@ namespace phi {
// NOTE: The name "InferShape" may be not appropriate. "InferMeta" may be good.
// Because functions in this file not only can infer shape, but also need
// infer lod or other useful data.
//
void AccuracyInferMeta(const MetaTensor& out,
const MetaTensor& indice,
const MetaTensor& label,
MetaTensor* accuracy,
MetaTensor* correct,
MetaTensor* total,
MetaConfig config = MetaConfig());
void AddmmInferMeta(const MetaTensor& input,
const MetaTensor& x,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册