Commit 175aa7ea authored by fengjiayi, committed by Yi Wang

add lod and dtype inference (#8329)

Parent 74492d5d
@@ -77,6 +77,8 @@ class CompileTimeInferShapeContext : public InferShapeContext {
   void SetRepeatedDims(const std::string &name,
                        const std::vector<DDim> &dims) override;
 
+  InferShapeVarPtr GetVarPtr(const std::string &name) override;
+
   const OpDesc &op_;
   const BlockDesc &block_;
 };
@@ -510,5 +512,10 @@ proto::VarDesc::VarType CompileTimeInferShapeContext::GetVarType(
   return block_.FindVarRecursive(name)->GetType();
 }
 
+InferShapeVarPtr CompileTimeInferShapeContext::GetVarPtr(
+    const std::string &name) {
+  return block_.FindVarRecursive(name);
+}
+
 }  // namespace framework
 }  // namespace paddle
@@ -470,6 +470,10 @@ class RuntimeInferShapeContext : public InferShapeContext {
     return ToVarType(var->Type());
   }
 
+  InferShapeVarPtr GetVarPtr(const std::string& name) override {
+    return scope_.FindVar(name);
+  }
+
  private:
   const OperatorBase& op_;
   const Scope& scope_;
......
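Together, the compile-time and runtime overrides above realize the new `InferShapeVarPtr` contract: at compile time `GetVarPtr` returns a `VarDesc*` found in the block, at runtime a `Variable*` found in the scope, and callers extract whichever alternative they expect with `boost::get`. A minimal standalone sketch of that dispatch, with stub types standing in for Paddle's real classes:

```cpp
#include <boost/variant.hpp>
#include <iostream>
#include <string>

// Stand-ins for paddle::framework::VarDesc / Variable; the real classes
// carry far more state, but only the variant dispatch matters here.
struct VarDesc { std::string name; };
struct Variable { std::string name; };

using InferShapeVarPtr = boost::variant<VarDesc *, Variable *>;

// Mimics how an InferShape implementation branches on IsRuntime() before
// extracting an alternative; boost::get throws boost::bad_get if the
// variant currently holds the other one.
void Inspect(const InferShapeVarPtr &ptr, bool is_runtime) {
  if (!is_runtime) {
    VarDesc *desc = boost::get<VarDesc *>(ptr);
    std::cout << "compile time: " << desc->name << "\n";
  } else {
    Variable *var = boost::get<Variable *>(ptr);
    std::cout << "runtime: " << var->name << "\n";
  }
}

int main() {
  VarDesc desc{"reader"};
  Variable var{"reader"};
  Inspect(InferShapeVarPtr(&desc), false);  // prints "compile time: reader"
  Inspect(InferShapeVarPtr(&var), true);    // prints "runtime: reader"
  return 0;
}
```

This is also why the reader ops below guard their `boost::get<framework::VarDesc*>` calls with an `IsRuntime()` check: requesting a `VarDesc*` from a runtime context would throw.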
@@ -90,7 +90,6 @@ void BatchReader::ReadNext(std::vector<LoDTensor>* out) {
     // Merge lod and data
     LoD batch_lod;
-    std::vector<size_t> top_level_lod({0});
     for (size_t i = 0; i < buffer_.size(); ++i) {
       DDim ins_shape = buffer_[i][j].dims();
       LoD ins_lod = buffer_[i][j].lod();
@@ -105,15 +104,10 @@ void BatchReader::ReadNext(std::vector<LoDTensor>* out) {
         }
       }
     }
-    top_level_lod.push_back(
-        top_level_lod.back() +
-        (ins_lod.empty() ? ins_shape[0] : (ins_lod[0].size() - 1)));
     Tensor dst = out_tensor.Slice(dst_offset, dst_offset + ins_shape[0]);
     Copy(buffer_[i][j], platform::CPUPlace(), &dst);
     dst_offset += ins_shape[0];
   }
-  batch_lod.insert(batch_lod.begin(), top_level_lod);
   out_tensor.set_lod(batch_lod);
   out->push_back(out_tensor);
 }
......
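For background, a LoD ("level of details") is a set of offset vectors marking sequence boundaries, and batching instances amounts to shifting each instance's offsets by the running end offset of the batch; the commit drops the extra top-level LoD the batch reader used to prepend. A simplified, self-contained sketch of the merge for a single LoD level (plain vectors instead of Paddle's `LoD` type, not the exact `BatchReader` code):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// One LoD level is an offset vector: {0, 2, 5} means two sequences of
// lengths 2 and 3. Each appended instance is shifted by the running end
// offset of the batch built so far.
std::vector<size_t> MergeLoDLevel(
    const std::vector<std::vector<size_t>> &instances) {
  std::vector<size_t> merged{0};
  for (const auto &ins : instances) {
    size_t base = merged.back();
    for (size_t k = 1; k < ins.size(); ++k) {
      merged.push_back(base + ins[k]);
    }
  }
  return merged;
}

int main() {
  // {0, 2, 5} followed by {0, 3} merges into {0, 2, 5, 8}.
  for (size_t v : MergeLoDLevel({{0, 2, 5}, {0, 3}})) std::cout << v << " ";
  std::cout << "\n";
  return 0;
}
```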
@@ -72,6 +72,28 @@ void InferShapeContext::SetReaderDims(const std::string &name,
   return this->SetRepeatedDims(arg_names[0], dims);
 }
 
+std::vector<InferShapeVarPtr> InferShapeContext::GetInputVarPtrs(
+    const std::string &name) {
+  const std::vector<std::string> arg_names = Inputs(name);
+  std::vector<InferShapeVarPtr> res;
+  res.reserve(arg_names.size());
+  std::transform(
+      arg_names.begin(), arg_names.end(), std::back_inserter(res),
+      [this](const std::string &name) { return this->GetVarPtr(name); });
+  return res;
+}
+
+std::vector<InferShapeVarPtr> InferShapeContext::GetOutputVarPtrs(
+    const std::string &name) {
+  const std::vector<std::string> arg_names = Outputs(name);
+  std::vector<InferShapeVarPtr> res;
+  res.reserve(arg_names.size());
+  std::transform(
+      arg_names.begin(), arg_names.end(), std::back_inserter(res),
+      [this](const std::string &name) { return this->GetVarPtr(name); });
+  return res;
+}
+
 std::vector<DDim> InferShapeContext::GetDims(
     const std::vector<std::string> &names) const {
   std::vector<DDim> ret;
......
@@ -17,10 +17,14 @@ limitations under the License. */
 #include "paddle/framework/attribute.h"
 #include "paddle/framework/ddim.h"
 #include "paddle/framework/framework.pb.h"
+#include "paddle/framework/var_desc.h"
+#include "paddle/framework/variable.h"
 
 namespace paddle {
 namespace framework {
 
+using InferShapeVarPtr = boost::variant<VarDesc *, Variable *>;
+
 class InferShapeContext {
  public:
   virtual ~InferShapeContext() = default;
@@ -55,6 +59,9 @@ class InferShapeContext {
   virtual bool IsRuntime() const = 0;
 
+  std::vector<InferShapeVarPtr> GetInputVarPtrs(const std::string &name);
+  std::vector<InferShapeVarPtr> GetOutputVarPtrs(const std::string &name);
+
   // Note: In while op, we need this to be public
   void SetDims(const std::vector<std::string> &names,
                const std::vector<DDim> &dims);
@@ -67,10 +74,13 @@ class InferShapeContext {
                        const std::vector<DDim> &dims) = 0;
 
   std::vector<DDim> GetDims(const std::vector<std::string> &names) const;
+
   std::vector<proto::VarDesc::VarType> GetVarTypes(
       const std::vector<std::string> &names) const;
 
   virtual proto::VarDesc::VarType GetVarType(const std::string &name) const = 0;
+
+  virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0;
 };
......
@@ -42,6 +42,18 @@ class CreateFileReaderInferShape : public framework::InferShapeBase {
     const auto ranks = ctx->Attrs().Get<std::vector<int>>("ranks");
     std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
     ctx->SetReaderDims("Out", shapes);
+
+    if (!ctx->IsRuntime()) {
+      const auto lod_levels = ctx->Attrs().Get<std::vector<int>>("lod_levels");
+      PADDLE_ENFORCE_EQ(
+          lod_levels.size(), shapes.size(),
+          "The number of 'lod_levels'(%d) doesn't match the number "
+          "of 'shapes'(%d).",
+          lod_levels.size(), shapes.size());
+      framework::VarDesc* reader =
+          boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
+      reader->SetLoDLevels(lod_levels);
+    }
   }
 };
@@ -54,11 +66,19 @@ class CreateDecoratedReaderInferShape : public framework::InferShapeBase {
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "The output decorated reader should not be null.");
     ctx->SetReaderDims("Out", ctx->GetReaderDims("UnderlyingReader"));
+
+    if (!ctx->IsRuntime()) {
+      framework::VarDesc* in_reader = boost::get<framework::VarDesc*>(
+          ctx->GetInputVarPtrs("UnderlyingReader")[0]);
+      framework::VarDesc* out_reader =
+          boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
+      out_reader->SetLoDLevels(in_reader->GetLoDLevels());
+    }
   }
 };
 
-// general var type inference for all readers
-class CreateReaderInferVarType : public framework::VarTypeInference {
+// general var type inference for file readers
+class CreateFileReaderInferVarType : public framework::VarTypeInference {
  public:
   void operator()(const framework::OpDesc& op_desc,
                   framework::BlockDesc* block) const override {
@@ -68,6 +88,20 @@ class CreateReaderInferVarType : public framework::VarTypeInference {
   }
 };
 
+// general var type inference for decorated readers
+class CreateDecoratedReaderInferVarType : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc& op_desc,
+                  framework::BlockDesc* block) const override {
+    std::string in_reader_name = op_desc.Input("UnderlyingReader")[0];
+    framework::VarDesc* in_reader = block->FindVarRecursive(in_reader_name);
+    std::string out_reader_name = op_desc.Output("Out")[0];
+    framework::VarDesc* out_reader = block->FindVarRecursive(out_reader_name);
+    out_reader->SetType(framework::proto::VarDesc::READER);
+    out_reader->SetDataTypes(in_reader->GetDataTypes());
+  }
+};
+
 template <typename T>
 class CreateRandomDataGeneratorOp : public framework::OperatorBase {
  public:
@@ -105,6 +139,7 @@ class CreateRandomDataGeneratorOpMaker
              "ranks = [3,2]"
              "It means the reader will generate two data each time,"
              "whose shapes are [2,3,4] and [5,6] respectively.");
+    AddAttr<std::vector<int>>("lod_levels", "The LoD levels of each data.");
     AddAttr<float>("min", "The lower bound of reader's uniform distribution.");
     AddAttr<float>("max", "The upper bound of reader's uniform distribution.");
     AddComment(R"DOC(
@@ -192,14 +227,14 @@ REGISTER_OPERATOR(create_random_data_generator,
                   ops::CreateFileReaderInferShape,
                   ops::CreateRandomDataGeneratorOpMaker,
                   paddle::framework::EmptyGradOpMaker,
-                  ops::CreateReaderInferVarType);
+                  ops::CreateFileReaderInferVarType);
 REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp,
                   ops::CreateDecoratedReaderInferShape,
                   ops::CreateShuffleReaderOpMaker,
                   paddle::framework::EmptyGradOpMaker,
-                  ops::CreateReaderInferVarType);
+                  ops::CreateDecoratedReaderInferVarType);
 REGISTER_OPERATOR(create_batch_reader, ops::CreateBatchReaderOp,
                   ops::CreateDecoratedReaderInferShape,
                   ops::CreateBatchReaderOpMaker,
                   paddle::framework::EmptyGradOpMaker,
-                  ops::CreateReaderInferVarType);
+                  ops::CreateDecoratedReaderInferVarType);
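The registrations pair each creator op with the matching inference classes: file readers take dtypes and LoD levels from attributes, while decorated readers copy them from the underlying reader, so a chain of decorators ends up with the file reader's metadata. A small standalone sketch of that copy-through pattern (stub `VarDesc` with hypothetical field names, not Paddle's API):

```cpp
#include <iostream>
#include <string>
#include <vector>

// Stub carrying only the metadata this pattern forwards; the real
// framework::VarDesc wraps a protobuf description.
struct VarDesc {
  std::vector<int> lod_levels;
  std::vector<std::string> dtypes;
};

// A decorated reader forwards both LoD levels and dtypes unchanged.
void Decorate(const VarDesc &in, VarDesc *out) {
  out->lod_levels = in.lod_levels;
  out->dtypes = in.dtypes;
}

int main() {
  VarDesc file_reader{{0, 0}, {"fp32", "fp32"}};
  VarDesc shuffle_reader, batch_reader;
  Decorate(file_reader, &shuffle_reader);   // shuffle keeps the metadata
  Decorate(shuffle_reader, &batch_reader);  // and so does batch
  std::cout << batch_reader.dtypes.size() << " outputs, dtype "
            << batch_reader.dtypes[0] << "\n";  // prints "2 outputs, dtype fp32"
  return 0;
}
```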
@@ -21,7 +21,8 @@ block = prog.current_block()
 random_reader = block.create_var(
     type=fluid.core.VarDesc.VarType.READER, name="RandomDataGenerator")
-random_reader.desc.set_lod_levels([0, 0])
+random_reader.desc.set_dtypes(
+    [fluid.core.DataType.FP32, fluid.core.DataType.FP32])
 
 create_random_data_generator_op = block.append_op(
     type="create_random_data_generator",
@@ -30,11 +31,11 @@ create_random_data_generator_op = block.append_op(
         "shape_concat": [1, 2, 1, 1],
         "ranks": [2, 2],
         "min": 0.0,
-        "max": 1.0
+        "max": 1.0,
+        'lod_levels': [0, 0]
     })
 
 shuffle_reader = block.create_var(
     type=fluid.core.VarDesc.VarType.READER, name="ShuffleReader")
-shuffle_reader.desc.set_lod_levels([0, 0])
 
 create_shuffle_reader_op = block.append_op(
     type="create_shuffle_reader",
@@ -44,7 +45,6 @@ create_shuffle_reader_op = block.append_op(
 batch_reader = block.create_var(
     type=fluid.core.VarDesc.VarType.READER, name="BatchReader")
-batch_reader.desc.set_lod_levels([1, 1])
 
 create_batch_reader_op = block.append_op(
     type="create_batch_reader",
@@ -62,11 +62,9 @@ read_op = block.append_op(
 place = fluid.CPUPlace()
 exe = fluid.Executor(place)
 
-[res1, res2] = exe.run(prog, fetch_list=[out1, out2], return_numpy=False)
+[res1, res2] = exe.run(prog, fetch_list=[out1, out2])
 
-test_pass = res1.lod() == [range(0, 11)] and res1.lod() == [
-    range(0, 11)
-] and np.array(res1).shape == (10, 2) and np.array(res2).shape == (10, 1)
+test_pass = res1.shape == (10, 2) and res2.shape == (10, 1)
 
 if not test_pass:
     exit(1)
......