Commit a308ff29 authored by Q qijun

make infershape of feedop and fetchop compatible with compile-time design

Parent 975a5129
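This commit moves all run-time work out of the InferShape methods of the feed and fetch operators. Under the compile-time design, shape inference runs against a program description before any scope exists, so InferShape may consult only the op's inputs, outputs, and attributes; touching the global "feed_value"/"fetch_value" variables there would fail before the program ever runs. After this change, feed infers its output shape from its "dims" attribute alone, fetch merely checks that its input exists, and every scope lookup and bounds check moves into the kernels. A condensed sketch of the resulting contract, assembled from the hunks below rather than quoted from any one file:

void FeedOp::InferShape(framework::InferShapeContext* ctx) const {
  PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output should be not null.");
  // Compile-time only: the shape comes from the "dims" attribute,
  // never from the global scope.
  auto& shape = ctx->Attrs().Get<std::vector<int>>("dims");
  std::vector<int64_t> shape_int64(shape.begin(), shape.end());
  ctx->SetOutputDim("Out", framework::make_ddim(shape_int64));
}

void FetchOp::InferShape(framework::InferShapeContext* ctx) const {
  // Presence check only; the output slot is shaped from the input
  // tensor at run time, inside FetchKernel::Compute.
  PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should be not null.");
}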
@@ -116,12 +116,8 @@ class ExecutorTesterRandom : public ::testing::Test {
           {{"dims", std::vector<int>{input_dim, embed_dim}}}, init_root_block);
     AddOp("gaussian_random", {}, {{"Out", {"w2"}}},
           {{"dims", std::vector<int>{embed_dim, input_dim}}}, init_root_block);
-    AddOp("fetch", {{"Input", {"w1"}}}, {},
-          {{"dims", std::vector<int>{input_dim, embed_dim}}, {"col", 0}},
-          init_root_block);
-    AddOp("fetch", {{"Input", {"w2"}}}, {},
-          {{"dims", std::vector<int>{embed_dim, input_dim}}, {"col", 1}},
-          init_root_block);
+    AddOp("fetch", {{"Input", {"w1"}}}, {}, {{"col", 0}}, init_root_block);
+    AddOp("fetch", {{"Input", {"w2"}}}, {}, {{"col", 1}}, init_root_block);
     // flush
     init_program.Proto();
@@ -163,12 +159,8 @@ class ExecutorTesterRandom : public ::testing::Test {
            {"Grad", {"w2@GRAD"}}},
           {{"ParamOut", {"w2"}}}, {}, root_block);
-    AddOp("fetch", {{"Input", {"w1"}}}, {},
-          {{"dims", std::vector<int>{input_dim, embed_dim}}, {"col", 0}},
-          root_block);
-    AddOp("fetch", {{"Input", {"w2"}}}, {},
-          {{"dims", std::vector<int>{embed_dim, input_dim}}, {"col", 1}},
-          root_block);
+    AddOp("fetch", {{"Input", {"w1"}}}, {}, {{"col", 0}}, root_block);
+    AddOp("fetch", {{"Input", {"w2"}}}, {}, {{"col", 1}}, root_block);
     // flush
     program.Proto();
@@ -197,10 +189,8 @@ class ExecutorTesterFeedAndFetch : public ::testing::Test {
           root_block);
     AddOp("feed", {}, {{"Out", {"b"}}}, {{"dims", dim}, {"col", 1}},
           root_block);
-    AddOp("fetch", {{"Input", {"a"}}}, {}, {{"dims", dim}, {"col", 0}},
-          root_block);
-    AddOp("fetch", {{"Input", {"b"}}}, {}, {{"dims", dim}, {"col", 1}},
-          root_block);
+    AddOp("fetch", {{"Input", {"a"}}}, {}, {{"col", 0}}, root_block);
+    AddOp("fetch", {{"Input", {"b"}}}, {}, {{"col", 1}}, root_block);
     // flush
     program.Proto();
...
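With the run-time checks gone from InferShape, the tests can drop the "dims" attribute from every fetch AddOp call: the fetched shape is discovered from the input tensor while the program runs. After Executor::Run, a test reads results straight out of the global fetch variable; a sketch for the w1 fetch in ExecutorTesterRandom above (the actual assertions in executor_test.cc may differ):

// Sketch: read back fetch column 0 after running the program.
auto* fetch_var = framework::GetGlobalScope()->FindVar("fetch_value");
const auto& fetched = fetch_var->Get<std::vector<framework::Tensor>>();
// FetchKernel resized the vector and shaped slot 0 from its input,
// so no "dims" attribute was needed at compile time.
EXPECT_EQ(fetched[0].dims(), framework::make_ddim({input_dim, embed_dim}));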
@@ -24,15 +24,6 @@ class FeedOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output should be not null.");
-    int col = ctx->Attrs().Get<int>("col");
-    framework::Variable* g_feed_variable =
-        framework::GetGlobalScope()->FindVar("feed_value");
-    const auto& tensors =
-        g_feed_variable->Get<std::vector<framework::Tensor>>();
-    PADDLE_ENFORCE_GT(tensors.size(), static_cast<size_t>(col));
     auto& shape = ctx->Attrs().Get<std::vector<int>>("dims");
     std::vector<int64_t> shape_int64(shape.size(), 0);
     std::transform(shape.begin(), shape.end(), shape_int64.begin(),
@@ -43,7 +34,7 @@ class FeedOp : public framework::OperatorWithKernel {
   framework::DataType IndicateDataType(
       const framework::ExecutionContext& ctx) const override {
-    return static_cast<framework::DataType>(Attr<int>("data_type"));
+    return static_cast<framework::DataType>(Attr<int>("dataType"));
   }
 };
@@ -51,7 +42,7 @@ class FeedOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   FeedOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddAttr<int>("data_type", "output data type")
+    AddAttr<int>("dataType", "output data type")
         .SetDefault(framework::DataType::FP32);
     AddAttr<int>("col", "The col in global feed variable").SetDefault(0);
     AddAttr<std::vector<int>>("dims", "The dimension of feed tensor.");
...
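Alongside the InferShape cleanup, the attribute key data_type becomes dataType here and again in fetch_op below, presumably to follow the camelCase attribute naming of the compile-time OpDesc. Feed keeps its "dims" attribute, since a feed op has no input from which a shape could be inferred at compile time. Any code assembling a feed op by hand must use the new key; a sketch with the AddOp test helper from executor_test.cc ("x", batch_size, and feature_dim are placeholders, not names from this commit):

AddOp("feed", {}, {{"Out", {"x"}}},
      {{"dataType", static_cast<int>(framework::DataType::FP32)},
       {"dims", std::vector<int>{batch_size, feature_dim}},
       {"col", 0}},
      root_block);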
@@ -27,9 +27,10 @@ class FeedKernel : public framework::OpKernel<T> {
     out->mutable_data<T>(ctx.GetPlace());
     framework::Variable* g_feed_variable =
         framework::GetGlobalScope()->FindVar("feed_value");
-    int col = ctx.template Attr<int>("col");
     const auto& tensors =
         g_feed_variable->Get<std::vector<framework::Tensor>>();
+    int col = ctx.template Attr<int>("col");
+    PADDLE_ENFORCE_GT(tensors.size(), static_cast<size_t>(col));
     out->CopyFrom<T>(tensors[col], ctx.GetPlace());
   }
 };
...
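FeedKernel is the consuming side: something must populate the global "feed_value" vector before the program runs, and the PADDLE_ENFORCE_GT that used to sit in InferShape now guards the access at the one point where the data actually exists. A sketch of the producing side a driver would run beforehand (SetFeedTensor is a hypothetical helper, not part of this commit; it mirrors the pattern FetchKernel uses below):

// Hypothetical driver-side helper: stage a tensor in the global
// "feed_value" vector so FeedKernel::Compute can copy it out.
void SetFeedTensor(int col, const framework::Tensor& src) {
  auto* var = framework::GetGlobalScope()->FindVar("feed_value");
  auto* feeds = var->GetMutable<std::vector<framework::Tensor>>();
  if (feeds->size() < static_cast<size_t>(col + 1)) {
    feeds->resize(col + 1);
  }
  feeds->at(col).Resize(src.dims());
  feeds->at(col).mutable_data<float>(platform::CPUPlace());
  feeds->at(col).CopyFrom<float>(src, platform::CPUPlace());
}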
@@ -24,26 +24,11 @@ class FetchOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should be not null.");
-    int col = ctx->Attrs().Get<int>("col");
-    framework::Variable* g_fetch_variable =
-        framework::GetGlobalScope()->FindVar("fetch_value");
-    auto* tensors =
-        g_fetch_variable->GetMutable<std::vector<framework::Tensor>>();
-    if (tensors->size() < static_cast<size_t>(col + 1)) {
-      tensors->resize(col + 1);
-    }
-    auto input_dim = ctx->GetInputDim("Input");
-    PADDLE_ENFORCE_GT(tensors->size(), col);
-    (*tensors)[col].Resize(input_dim);
-    // TODO(qijun): need to handle LodTensor later
   }
   framework::DataType IndicateDataType(
       const framework::ExecutionContext& ctx) const override {
-    return static_cast<framework::DataType>(Attr<int>("data_type"));
+    return static_cast<framework::DataType>(Attr<int>("dataType"));
   }
 };
@@ -51,10 +36,9 @@ class FetchOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   FetchOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddAttr<int>("data_type", "output data type")
+    AddAttr<int>("dataType", "output data type")
         .SetDefault(framework::DataType::FP32);
     AddAttr<int>("col", "The col in global fetch variable").SetDefault(0);
-    AddAttr<std::vector<int>>("dims", "The dimension of fetch tensor.");
     AddInput("Input", "The output of fetch op.");
     AddComment(R"DOC(Fetch data to global fetch variable)DOC");
   }
...
@@ -24,13 +24,19 @@ class FetchKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     const framework::Tensor* input = ctx.Input<framework::Tensor>("Input");
-    int col = ctx.template Attr<int>("col");
     framework::Variable* g_fetch_variable =
         framework::GetGlobalScope()->FindVar("fetch_value");
     auto* tensors =
         g_fetch_variable->GetMutable<std::vector<framework::Tensor>>();
+    int col = ctx.template Attr<int>("col");
+    if (tensors->size() < static_cast<size_t>(col + 1)) {
+      tensors->resize(col + 1);
+    }
+    PADDLE_ENFORCE_GT(tensors->size(), static_cast<size_t>(col));
+    (*tensors)[col].Resize(input->dims());
     (*tensors)[col].mutable_data<T>(platform::CPUPlace());
     (*tensors)[col].CopyFrom<T>(*input, platform::CPUPlace());
+    // TODO(qijun): need to handle LodTensor later
   }
 };
...
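With this change the run-time responsibilities are symmetric: FeedKernel reads feed_value[col] into its output, FetchKernel writes its input into fetch_value[col], and each does its own bounds handling at the point where the data exists. The TODO about LoDTensor moves along with the code it annotates, since LoD information is likewise available only at run time.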