提交 175aa7ea 编写于 作者: F fengjiayi 提交者: Yi Wang

add lod and dtype inference (#8329)

上级 74492d5d
...@@ -77,6 +77,8 @@ class CompileTimeInferShapeContext : public InferShapeContext { ...@@ -77,6 +77,8 @@ class CompileTimeInferShapeContext : public InferShapeContext {
void SetRepeatedDims(const std::string &name, void SetRepeatedDims(const std::string &name,
const std::vector<DDim> &dims) override; const std::vector<DDim> &dims) override;
InferShapeVarPtr GetVarPtr(const std::string &name) override;
const OpDesc &op_; const OpDesc &op_;
const BlockDesc &block_; const BlockDesc &block_;
}; };
...@@ -510,5 +512,10 @@ proto::VarDesc::VarType CompileTimeInferShapeContext::GetVarType( ...@@ -510,5 +512,10 @@ proto::VarDesc::VarType CompileTimeInferShapeContext::GetVarType(
return block_.FindVarRecursive(name)->GetType(); return block_.FindVarRecursive(name)->GetType();
} }
InferShapeVarPtr CompileTimeInferShapeContext::GetVarPtr(
const std::string &name) {
return block_.FindVarRecursive(name);
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -470,6 +470,10 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -470,6 +470,10 @@ class RuntimeInferShapeContext : public InferShapeContext {
return ToVarType(var->Type()); return ToVarType(var->Type());
} }
InferShapeVarPtr GetVarPtr(const std::string& name) override {
return scope_.FindVar(name);
}
private: private:
const OperatorBase& op_; const OperatorBase& op_;
const Scope& scope_; const Scope& scope_;
......
...@@ -90,7 +90,6 @@ void BatchReader::ReadNext(std::vector<LoDTensor>* out) { ...@@ -90,7 +90,6 @@ void BatchReader::ReadNext(std::vector<LoDTensor>* out) {
// Merge lod and data // Merge lod and data
LoD batch_lod; LoD batch_lod;
std::vector<size_t> top_level_lod({0});
for (size_t i = 0; i < buffer_.size(); ++i) { for (size_t i = 0; i < buffer_.size(); ++i) {
DDim ins_shape = buffer_[i][j].dims(); DDim ins_shape = buffer_[i][j].dims();
LoD ins_lod = buffer_[i][j].lod(); LoD ins_lod = buffer_[i][j].lod();
...@@ -105,15 +104,10 @@ void BatchReader::ReadNext(std::vector<LoDTensor>* out) { ...@@ -105,15 +104,10 @@ void BatchReader::ReadNext(std::vector<LoDTensor>* out) {
} }
} }
} }
top_level_lod.push_back(
top_level_lod.back() +
(ins_lod.empty() ? ins_shape[0] : (ins_lod[0].size() - 1)));
Tensor dst = out_tensor.Slice(dst_offset, dst_offset + ins_shape[0]); Tensor dst = out_tensor.Slice(dst_offset, dst_offset + ins_shape[0]);
Copy(buffer_[i][j], platform::CPUPlace(), &dst); Copy(buffer_[i][j], platform::CPUPlace(), &dst);
dst_offset += ins_shape[0]; dst_offset += ins_shape[0];
} }
batch_lod.insert(batch_lod.begin(), top_level_lod);
out_tensor.set_lod(batch_lod); out_tensor.set_lod(batch_lod);
out->push_back(out_tensor); out->push_back(out_tensor);
} }
......
...@@ -72,6 +72,28 @@ void InferShapeContext::SetReaderDims(const std::string &name, ...@@ -72,6 +72,28 @@ void InferShapeContext::SetReaderDims(const std::string &name,
return this->SetRepeatedDims(arg_names[0], dims); return this->SetRepeatedDims(arg_names[0], dims);
} }
std::vector<InferShapeVarPtr> InferShapeContext::GetInputVarPtrs(
const std::string &name) {
const std::vector<std::string> arg_names = Inputs(name);
std::vector<InferShapeVarPtr> res;
res.reserve(arg_names.size());
std::transform(
arg_names.begin(), arg_names.end(), std::back_inserter(res),
[this](const std::string &name) { return this->GetVarPtr(name); });
return res;
}
std::vector<InferShapeVarPtr> InferShapeContext::GetOutputVarPtrs(
const std::string &name) {
const std::vector<std::string> arg_names = Outputs(name);
std::vector<InferShapeVarPtr> res;
res.reserve(arg_names.size());
std::transform(
arg_names.begin(), arg_names.end(), std::back_inserter(res),
[this](const std::string &name) { return this->GetVarPtr(name); });
return res;
}
std::vector<DDim> InferShapeContext::GetDims( std::vector<DDim> InferShapeContext::GetDims(
const std::vector<std::string> &names) const { const std::vector<std::string> &names) const {
std::vector<DDim> ret; std::vector<DDim> ret;
......
...@@ -17,10 +17,14 @@ limitations under the License. */ ...@@ -17,10 +17,14 @@ limitations under the License. */
#include "paddle/framework/attribute.h" #include "paddle/framework/attribute.h"
#include "paddle/framework/ddim.h" #include "paddle/framework/ddim.h"
#include "paddle/framework/framework.pb.h" #include "paddle/framework/framework.pb.h"
#include "paddle/framework/var_desc.h"
#include "paddle/framework/variable.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
using InferShapeVarPtr = boost::variant<VarDesc *, Variable *>;
class InferShapeContext { class InferShapeContext {
public: public:
virtual ~InferShapeContext() = default; virtual ~InferShapeContext() = default;
...@@ -55,6 +59,9 @@ class InferShapeContext { ...@@ -55,6 +59,9 @@ class InferShapeContext {
virtual bool IsRuntime() const = 0; virtual bool IsRuntime() const = 0;
std::vector<InferShapeVarPtr> GetInputVarPtrs(const std::string &name);
std::vector<InferShapeVarPtr> GetOutputVarPtrs(const std::string &name);
// Note: In while op, we need this to be public // Note: In while op, we need this to be public
void SetDims(const std::vector<std::string> &names, void SetDims(const std::vector<std::string> &names,
const std::vector<DDim> &dims); const std::vector<DDim> &dims);
...@@ -67,10 +74,13 @@ class InferShapeContext { ...@@ -67,10 +74,13 @@ class InferShapeContext {
const std::vector<DDim> &dims) = 0; const std::vector<DDim> &dims) = 0;
std::vector<DDim> GetDims(const std::vector<std::string> &names) const; std::vector<DDim> GetDims(const std::vector<std::string> &names) const;
std::vector<proto::VarDesc::VarType> GetVarTypes( std::vector<proto::VarDesc::VarType> GetVarTypes(
const std::vector<std::string> &names) const; const std::vector<std::string> &names) const;
virtual proto::VarDesc::VarType GetVarType(const std::string &name) const = 0; virtual proto::VarDesc::VarType GetVarType(const std::string &name) const = 0;
virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0;
}; };
} // namespace framework } // namespace framework
......
...@@ -42,6 +42,18 @@ class CreateFileReaderInferShape : public framework::InferShapeBase { ...@@ -42,6 +42,18 @@ class CreateFileReaderInferShape : public framework::InferShapeBase {
const auto ranks = ctx->Attrs().Get<std::vector<int>>("ranks"); const auto ranks = ctx->Attrs().Get<std::vector<int>>("ranks");
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks); std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
ctx->SetReaderDims("Out", shapes); ctx->SetReaderDims("Out", shapes);
if (ctx->IsRuntime()) {
const auto lod_levels = ctx->Attrs().Get<std::vector<int>>("lod_levels");
PADDLE_ENFORCE_EQ(
lod_levels.size(), shapes.size(),
"The number of 'lod_levels'(%d) doesn't match the number "
"of 'shapes'(%d).",
lod_levels.size(), shapes.size());
framework::VarDesc* reader =
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
reader->SetLoDLevels(lod_levels);
}
} }
}; };
...@@ -54,11 +66,19 @@ class CreateDecoratedReaderInferShape : public framework::InferShapeBase { ...@@ -54,11 +66,19 @@ class CreateDecoratedReaderInferShape : public framework::InferShapeBase {
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"The output decorated reader should not be null."); "The output decorated reader should not be null.");
ctx->SetReaderDims("Out", ctx->GetReaderDims("UnderlyingReader")); ctx->SetReaderDims("Out", ctx->GetReaderDims("UnderlyingReader"));
if (ctx->IsRuntime()) {
framework::VarDesc* in_reader = boost::get<framework::VarDesc*>(
ctx->GetInputVarPtrs("UnderlyingReader")[0]);
framework::VarDesc* out_reader =
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
out_reader->SetLoDLevels(in_reader->GetLoDLevels());
}
} }
}; };
// general var type inference for all readers // general var type inference for file readers
class CreateReaderInferVarType : public framework::VarTypeInference { class CreateFileReaderInferVarType : public framework::VarTypeInference {
public: public:
void operator()(const framework::OpDesc& op_desc, void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override { framework::BlockDesc* block) const override {
...@@ -68,6 +88,20 @@ class CreateReaderInferVarType : public framework::VarTypeInference { ...@@ -68,6 +88,20 @@ class CreateReaderInferVarType : public framework::VarTypeInference {
} }
}; };
// general var type inference for decorated readers
class CreateDecoratedReaderInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
std::string in_reader_name = op_desc.Input("UnderlyingReader")[0];
framework::VarDesc* in_reader = block->FindVarRecursive(in_reader_name);
std::string out_reader_name = op_desc.Output("Out")[0];
framework::VarDesc* out_reader = block->FindVarRecursive(out_reader_name);
out_reader->SetType(framework::proto::VarDesc::READER);
out_reader->SetDataTypes(in_reader->GetDataTypes());
}
};
template <typename T> template <typename T>
class CreateRandomDataGeneratorOp : public framework::OperatorBase { class CreateRandomDataGeneratorOp : public framework::OperatorBase {
public: public:
...@@ -105,6 +139,7 @@ class CreateRandomDataGeneratorOpMaker ...@@ -105,6 +139,7 @@ class CreateRandomDataGeneratorOpMaker
"ranks = [3,2]" "ranks = [3,2]"
"It means the reader will generate two data each time," "It means the reader will generate two data each time,"
"whose shapes are [2,3,4] and [5,6] respectively."); "whose shapes are [2,3,4] and [5,6] respectively.");
AddAttr<std::vector<int>>("lod_levels", "The LoD levels of each data.");
AddAttr<float>("min", "The lower bound of reader's uniform distribution."); AddAttr<float>("min", "The lower bound of reader's uniform distribution.");
AddAttr<float>("max", "The upper bound of reader's uniform distribution."); AddAttr<float>("max", "The upper bound of reader's uniform distribution.");
AddComment(R"DOC( AddComment(R"DOC(
...@@ -192,14 +227,14 @@ REGISTER_OPERATOR(create_random_data_generator, ...@@ -192,14 +227,14 @@ REGISTER_OPERATOR(create_random_data_generator,
ops::CreateFileReaderInferShape, ops::CreateFileReaderInferShape,
ops::CreateRandomDataGeneratorOpMaker, ops::CreateRandomDataGeneratorOpMaker,
paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker,
ops::CreateReaderInferVarType); ops::CreateFileReaderInferVarType);
REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp, REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp,
ops::CreateDecoratedReaderInferShape, ops::CreateDecoratedReaderInferShape,
ops::CreateShuffleReaderOpMaker, ops::CreateShuffleReaderOpMaker,
paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker,
ops::CreateReaderInferVarType); ops::CreateDecoratedReaderInferVarType);
REGISTER_OPERATOR(create_batch_reader, ops::CreateBatchReaderOp, REGISTER_OPERATOR(create_batch_reader, ops::CreateBatchReaderOp,
ops::CreateDecoratedReaderInferShape, ops::CreateDecoratedReaderInferShape,
ops::CreateBatchReaderOpMaker, ops::CreateBatchReaderOpMaker,
paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker,
ops::CreateReaderInferVarType); ops::CreateDecoratedReaderInferVarType);
...@@ -21,7 +21,8 @@ block = prog.current_block() ...@@ -21,7 +21,8 @@ block = prog.current_block()
random_reader = block.create_var( random_reader = block.create_var(
type=fluid.core.VarDesc.VarType.READER, name="RandomDataGenerator") type=fluid.core.VarDesc.VarType.READER, name="RandomDataGenerator")
random_reader.desc.set_lod_levels([0, 0]) random_reader.desc.set_dtypes(
[fluid.core.DataType.FP32, fluid.core.DataType.FP32])
create_random_data_generator_op = block.append_op( create_random_data_generator_op = block.append_op(
type="create_random_data_generator", type="create_random_data_generator",
...@@ -30,11 +31,11 @@ create_random_data_generator_op = block.append_op( ...@@ -30,11 +31,11 @@ create_random_data_generator_op = block.append_op(
"shape_concat": [1, 2, 1, 1], "shape_concat": [1, 2, 1, 1],
"ranks": [2, 2], "ranks": [2, 2],
"min": 0.0, "min": 0.0,
"max": 1.0 "max": 1.0,
'lod_levels': [0, 0]
}) })
shuffle_reader = block.create_var( shuffle_reader = block.create_var(
type=fluid.core.VarDesc.VarType.READER, name="ShuffleReader") type=fluid.core.VarDesc.VarType.READER, name="ShuffleReader")
shuffle_reader.desc.set_lod_levels([0, 0])
create_shuffle_reader_op = block.append_op( create_shuffle_reader_op = block.append_op(
type="create_shuffle_reader", type="create_shuffle_reader",
...@@ -44,7 +45,6 @@ create_shuffle_reader_op = block.append_op( ...@@ -44,7 +45,6 @@ create_shuffle_reader_op = block.append_op(
batch_reader = block.create_var( batch_reader = block.create_var(
type=fluid.core.VarDesc.VarType.READER, name="BatchReader") type=fluid.core.VarDesc.VarType.READER, name="BatchReader")
batch_reader.desc.set_lod_levels([1, 1])
create_batch_reader_op = block.append_op( create_batch_reader_op = block.append_op(
type="create_batch_reader", type="create_batch_reader",
...@@ -62,11 +62,9 @@ read_op = block.append_op( ...@@ -62,11 +62,9 @@ read_op = block.append_op(
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
[res1, res2] = exe.run(prog, fetch_list=[out1, out2], return_numpy=False) [res1, res2] = exe.run(prog, fetch_list=[out1, out2])
test_pass = res1.lod() == [range(0, 11)] and res1.lod() == [ test_pass = res1.shape == (10, 2) and res2.shape == (10, 1)
range(0, 11)
] and np.array(res1).shape == (10, 2) and np.array(res2).shape == (10, 1)
if not test_pass: if not test_pass:
exit(1) exit(1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册