未验证 提交 6776e928 编写于 作者: C chengduo 提交者: GitHub

refine tensor_array_write_read (#14643)

test=develop
上级 dfd4a111
...@@ -81,13 +81,35 @@ class CompileTimeInferShapeContext : public InferShapeContext { ...@@ -81,13 +81,35 @@ class CompileTimeInferShapeContext : public InferShapeContext {
"The %s[%d] is @EMPTY@", out, j); "The %s[%d] is @EMPTY@", out, j);
auto *in_var = block_.FindVarRecursive(Inputs(in)[i]); auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
auto *out_var = block_.FindVarRecursive(Outputs(out)[j]); auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
if (in_var->GetType() != proto::VarType::LOD_TENSOR) { if (in_var->GetType() != proto::VarType::LOD_TENSOR &&
VLOG(3) << "input " << in << " is not LodTensor"; in_var->GetType() != proto::VarType::LOD_TENSOR_ARRAY) {
VLOG(3) << "input " << in << " is not LodTensor or LodTensorArray.";
return; return;
} }
out_var->SetLoDLevel(in_var->GetLoDLevel()); out_var->SetLoDLevel(in_var->GetLoDLevel());
} }
void DecreaseLoDLevel(const std::string &in, const std::string &out,
size_t i = 0, size_t j = 0) const override {
PADDLE_ENFORCE_LT(i, Inputs(in).size());
PADDLE_ENFORCE_LT(j, Outputs(out).size());
PADDLE_ENFORCE(Inputs(in)[i] != framework::kEmptyVarName,
"The %s[%d] is @EMPTY@", in, i);
PADDLE_ENFORCE(Outputs(out)[j] != framework::kEmptyVarName,
"The %s[%d] is @EMPTY@", out, j);
auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
PADDLE_ENFORCE(out_var->GetType() == proto::VarType::LOD_TENSOR_ARRAY ||
out_var->GetType() == proto::VarType::LOD_TENSOR,
"The input %s should be LodTensorArray or LodTensor.",
out_var->Name());
PADDLE_ENFORCE(in_var->GetType() == proto::VarType::LOD_TENSOR,
"The input %s should be LodTensor.", in_var->Name());
if (in_var->GetLoDLevel() > 0) {
out_var->SetLoDLevel(in_var->GetLoDLevel() - 1);
}
}
bool IsRuntime() const override; bool IsRuntime() const override;
protected: protected:
......
...@@ -623,6 +623,11 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -623,6 +623,11 @@ class RuntimeInferShapeContext : public InferShapeContext {
out_tensor->set_layout(in_tensor.layout()); out_tensor->set_layout(in_tensor.layout());
} }
void DecreaseLoDLevel(const std::string& in, const std::string& out,
size_t i = 0, size_t j = 0) const override {
PADDLE_THROW("DecreaseLoDLevel is only used in compile time.");
}
bool IsRuntime() const override { return true; } bool IsRuntime() const override { return true; }
protected: protected:
......
...@@ -62,6 +62,9 @@ class InferShapeContext { ...@@ -62,6 +62,9 @@ class InferShapeContext {
virtual void ShareLoD(const std::string &in, const std::string &out, virtual void ShareLoD(const std::string &in, const std::string &out,
size_t i = 0, size_t j = 0) const = 0; size_t i = 0, size_t j = 0) const = 0;
virtual void DecreaseLoDLevel(const std::string &in, const std::string &out,
size_t i = 0, size_t j = 0) const = 0;
virtual bool IsRuntime() const = 0; virtual bool IsRuntime() const = 0;
std::vector<InferShapeVarPtr> GetInputVarPtrs(const std::string &name); std::vector<InferShapeVarPtr> GetInputVarPtrs(const std::string &name);
......
...@@ -167,6 +167,19 @@ $$T = A[i]$$ ...@@ -167,6 +167,19 @@ $$T = A[i]$$
}; };
class ReadFromArrayInferShape : public WriteToArrayInferShape { class ReadFromArrayInferShape : public WriteToArrayInferShape {
public:
void operator()(framework::InferShapeContext *context) const override {
WriteToArrayInferShape::operator()(context);
if (!context->HasInput("X")) {
return;
}
// FIXME: just for compile time.
if (!context->IsRuntime()) {
context->ShareLoD("X", /*->*/ "Out");
}
}
protected: protected:
const char *NotHasXError() const override { const char *NotHasXError() const override {
return "The input array X must be set"; return "The input array X must be set";
......
...@@ -192,6 +192,10 @@ class LoDTensorToArrayInferShape : public framework::InferShapeBase { ...@@ -192,6 +192,10 @@ class LoDTensorToArrayInferShape : public framework::InferShapeBase {
// The first dim of each LoDTensor in Output can only be set at run-time.; // The first dim of each LoDTensor in Output can only be set at run-time.;
// We still have to Resize each LoDTensor in Output. // We still have to Resize each LoDTensor in Output.
context->SetOutputDim("Out", x_dim); context->SetOutputDim("Out", x_dim);
// The lod level should be passed to out in compile time.
if (!context->IsRuntime()) {
context->DecreaseLoDLevel("X", /*->*/ "Out");
}
} }
}; };
......
...@@ -201,6 +201,9 @@ class IdentityInferShape : public framework::InferShapeBase { ...@@ -201,6 +201,9 @@ class IdentityInferShape : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext *context) const override { void operator()(framework::InferShapeContext *context) const override {
context->SetOutputDim("Out", context->GetInputDim("X")); context->SetOutputDim("Out", context->GetInputDim("X"));
if (!context->IsRuntime()) {
context->ShareLoD("X", /*->*/ "Out");
}
} }
}; };
......
...@@ -100,6 +100,9 @@ class ShrinkRNNMemoryInferShape : public framework::InferShapeBase { ...@@ -100,6 +100,9 @@ class ShrinkRNNMemoryInferShape : public framework::InferShapeBase {
PADDLE_ENFORCE(context->HasInput("I")); PADDLE_ENFORCE(context->HasInput("I"));
PADDLE_ENFORCE(context->HasInput("RankTable")); PADDLE_ENFORCE(context->HasInput("RankTable"));
context->SetOutputDim("Out", context->GetInputDim("X")); context->SetOutputDim("Out", context->GetInputDim("X"));
if (!context->IsRuntime()) {
context->DecreaseLoDLevel("X", /*->*/ "Out");
}
} }
}; };
......
...@@ -172,6 +172,7 @@ class TestDynRNN(unittest.TestCase): ...@@ -172,6 +172,7 @@ class TestDynRNN(unittest.TestCase):
rnn = fluid.layers.DynamicRNN() rnn = fluid.layers.DynamicRNN()
with rnn.block(): with rnn.block():
in_ = rnn.step_input(sentence) in_ = rnn.step_input(sentence)
assert in_.lod_level == 1, "the lod level of in_ should be 1"
sent_emb = fluid.layers.embedding( sent_emb = fluid.layers.embedding(
input=in_, size=[len(word_dict), 32], dtype='float32') input=in_, size=[len(word_dict), 32], dtype='float32')
out_ = fluid.layers.fc(input=sent_emb, size=100, act='tanh') out_ = fluid.layers.fc(input=sent_emb, size=100, act='tanh')
...@@ -179,6 +180,7 @@ class TestDynRNN(unittest.TestCase): ...@@ -179,6 +180,7 @@ class TestDynRNN(unittest.TestCase):
rnn1 = fluid.layers.DynamicRNN() rnn1 = fluid.layers.DynamicRNN()
with rnn1.block(): with rnn1.block():
in_1 = rnn1.step_input(out_) in_1 = rnn1.step_input(out_)
assert in_1.lod_level == 0, "the lod level of in_1 should be 0"
out_1 = fluid.layers.fc(input=[in_1], size=100, act='tanh') out_1 = fluid.layers.fc(input=[in_1], size=100, act='tanh')
rnn1.output(out_1) rnn1.output(out_1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册