提交 a308ff29 编写于 作者: Q qijun

make infershape of feedop and fetchop compatible with compile-time design

上级 975a5129
......@@ -116,12 +116,8 @@ class ExecutorTesterRandom : public ::testing::Test {
{{"dims", std::vector<int>{input_dim, embed_dim}}}, init_root_block);
AddOp("gaussian_random", {}, {{"Out", {"w2"}}},
{{"dims", std::vector<int>{embed_dim, input_dim}}}, init_root_block);
AddOp("fetch", {{"Input", {"w1"}}}, {},
{{"dims", std::vector<int>{input_dim, embed_dim}}, {"col", 0}},
init_root_block);
AddOp("fetch", {{"Input", {"w2"}}}, {},
{{"dims", std::vector<int>{embed_dim, input_dim}}, {"col", 1}},
init_root_block);
AddOp("fetch", {{"Input", {"w1"}}}, {}, {{"col", 0}}, init_root_block);
AddOp("fetch", {{"Input", {"w2"}}}, {}, {{"col", 1}}, init_root_block);
// flush
init_program.Proto();
......@@ -163,12 +159,8 @@ class ExecutorTesterRandom : public ::testing::Test {
{"Grad", {"w2@GRAD"}}},
{{"ParamOut", {"w2"}}}, {}, root_block);
AddOp("fetch", {{"Input", {"w1"}}}, {},
{{"dims", std::vector<int>{input_dim, embed_dim}}, {"col", 0}},
root_block);
AddOp("fetch", {{"Input", {"w2"}}}, {},
{{"dims", std::vector<int>{embed_dim, input_dim}}, {"col", 1}},
root_block);
AddOp("fetch", {{"Input", {"w1"}}}, {}, {{"col", 0}}, root_block);
AddOp("fetch", {{"Input", {"w2"}}}, {}, {{"col", 1}}, root_block);
// flush
program.Proto();
......@@ -197,10 +189,8 @@ class ExecutorTesterFeedAndFetch : public ::testing::Test {
root_block);
AddOp("feed", {}, {{"Out", {"b"}}}, {{"dims", dim}, {"col", 1}},
root_block);
AddOp("fetch", {{"Input", {"a"}}}, {}, {{"dims", dim}, {"col", 0}},
root_block);
AddOp("fetch", {{"Input", {"b"}}}, {}, {{"dims", dim}, {"col", 1}},
root_block);
AddOp("fetch", {{"Input", {"a"}}}, {}, {{"col", 0}}, root_block);
AddOp("fetch", {{"Input", {"b"}}}, {}, {{"col", 1}}, root_block);
// flush
program.Proto();
......
......@@ -24,15 +24,6 @@ class FeedOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output should be not null.");
int col = ctx->Attrs().Get<int>("col");
framework::Variable* g_feed_variable =
framework::GetGlobalScope()->FindVar("feed_value");
const auto& tensors =
g_feed_variable->Get<std::vector<framework::Tensor>>();
PADDLE_ENFORCE_GT(tensors.size(), static_cast<size_t>(col));
auto& shape = ctx->Attrs().Get<std::vector<int>>("dims");
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
......@@ -43,7 +34,7 @@ class FeedOp : public framework::OperatorWithKernel {
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return static_cast<framework::DataType>(Attr<int>("data_type"));
return static_cast<framework::DataType>(Attr<int>("dataType"));
}
};
......@@ -51,7 +42,7 @@ class FeedOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FeedOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<int>("data_type", "output data type")
AddAttr<int>("dataType", "output data type")
.SetDefault(framework::DataType::FP32);
AddAttr<int>("col", "The col in global feed variable").SetDefault(0);
AddAttr<std::vector<int>>("dims", "The dimension of feed tensor.");
......
......@@ -27,9 +27,10 @@ class FeedKernel : public framework::OpKernel<T> {
out->mutable_data<T>(ctx.GetPlace());
framework::Variable* g_feed_variable =
framework::GetGlobalScope()->FindVar("feed_value");
int col = ctx.template Attr<int>("col");
const auto& tensors =
g_feed_variable->Get<std::vector<framework::Tensor>>();
int col = ctx.template Attr<int>("col");
PADDLE_ENFORCE_GT(tensors.size(), static_cast<size_t>(col));
out->CopyFrom<T>(tensors[col], ctx.GetPlace());
}
};
......
......@@ -24,26 +24,11 @@ class FetchOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should be not null.");
int col = ctx->Attrs().Get<int>("col");
framework::Variable* g_fetch_variable =
framework::GetGlobalScope()->FindVar("fetch_value");
auto* tensors =
g_fetch_variable->GetMutable<std::vector<framework::Tensor>>();
if (tensors->size() < static_cast<size_t>(col + 1)) {
tensors->resize(col + 1);
}
auto input_dim = ctx->GetInputDim("Input");
PADDLE_ENFORCE_GT(tensors->size(), col);
(*tensors)[col].Resize(input_dim);
// TODO(qijun): need to handle LodTensor later
}
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return static_cast<framework::DataType>(Attr<int>("data_type"));
return static_cast<framework::DataType>(Attr<int>("dataType"));
}
};
......@@ -51,10 +36,9 @@ class FetchOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FetchOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<int>("data_type", "output data type")
AddAttr<int>("dataType", "output data type")
.SetDefault(framework::DataType::FP32);
AddAttr<int>("col", "The col in global fetch variable").SetDefault(0);
AddAttr<std::vector<int>>("dims", "The dimension of fetch tensor.");
AddInput("Input", "The output of fetch op.");
AddComment(R"DOC(Fetch data to global fetch variable)DOC");
}
......
......@@ -24,13 +24,19 @@ class FetchKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor* input = ctx.Input<framework::Tensor>("Input");
int col = ctx.template Attr<int>("col");
framework::Variable* g_fetch_variable =
framework::GetGlobalScope()->FindVar("fetch_value");
auto* tensors =
g_fetch_variable->GetMutable<std::vector<framework::Tensor>>();
int col = ctx.template Attr<int>("col");
if (tensors->size() < static_cast<size_t>(col + 1)) {
tensors->resize(col + 1);
}
PADDLE_ENFORCE_GT(tensors->size(), static_cast<size_t>(col));
(*tensors)[col].Resize(input->dims());
(*tensors)[col].mutable_data<T>(platform::CPUPlace());
(*tensors)[col].CopyFrom<T>(*input, platform::CPUPlace());
// TODO(qijun): need to handle LodTensor later
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册