提交 225efa67 编写于 作者: Y Yu Yang

Remove dims in base class

上级 2ea4a5d9
......@@ -442,15 +442,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
Variable* var = scope_.FindVar(name);
if (var->IsType<ReaderHolder>()) {
return var->Get<ReaderHolder>().shapes();
} else {
PADDLE_THROW(
"Only ReaderHolder support 'GetRepeatedDims', but Variable %s's "
"type_id is %s.",
name, var->Type().name());
}
PADDLE_THROW("Only compile time support this method");
}
void SetDim(const std::string& name, const DDim& dim) override {
......@@ -467,15 +459,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
void SetRepeatedDims(const std::string& name,
const std::vector<DDim>& dims) override {
Variable* var = scope_.FindVar(name);
if (var->IsType<ReaderHolder>()) {
var->GetMutable<ReaderHolder>()->set_shapes(dims);
} else {
PADDLE_THROW(
"Only ReaderHolder support 'SetRepeatedDims', but Variable %s's "
"type_id is %s.",
name, var->Type().name());
}
PADDLE_THROW("Only compile time support this method");
}
proto::VarType::Type GetVarType(const std::string& name) const override {
......
......@@ -16,14 +16,6 @@
namespace paddle {
namespace framework {
DDim ReaderBase::shape(size_t idx) const {
PADDLE_ENFORCE_LT(
idx, shapes_.size(),
"Cannot get the %d'th shape, 'shapes_' only has %d elements.", idx,
shapes_.size());
return shapes_[idx];
}
ReaderBase::~ReaderBase() {}
} // namespace framework
} // namespace paddle
......@@ -22,34 +22,18 @@ namespace framework {
class ReaderBase {
public:
explicit ReaderBase(const std::vector<DDim>& shapes) : shapes_(shapes) {
PADDLE_ENFORCE(!shapes_.empty());
}
virtual void ReadNext(std::vector<LoDTensor>* out) = 0;
virtual void ReInit() = 0;
DDim shape(size_t idx) const;
std::vector<DDim> shapes() const { return shapes_; }
void set_shapes(const std::vector<DDim>& shapes) { shapes_ = shapes; }
virtual bool HasNext() const = 0;
virtual ~ReaderBase() {}
protected:
std::vector<DDim> shapes_;
};
class FileReader : public ReaderBase {
public:
explicit FileReader(const std::vector<DDim>& shapes) : ReaderBase(shapes) {}
virtual ~ReaderBase();
};
class DecoratedReader : public ReaderBase {
public:
explicit DecoratedReader(ReaderBase* reader)
: ReaderBase(reader->shapes()), reader_(reader) {
explicit DecoratedReader(ReaderBase* reader) : ReaderBase(), reader_(reader) {
PADDLE_ENFORCE_NOT_NULL(reader_);
}
......@@ -72,12 +56,6 @@ class ReaderHolder {
void ReadNext(std::vector<LoDTensor>* out) { reader_->ReadNext(out); }
void ReInit() { reader_->ReInit(); }
DDim shape(size_t idx) const { return reader_->shape(idx); }
std::vector<DDim> shapes() const { return reader_->shapes(); }
void set_shapes(const std::vector<DDim>& shapes) {
reader_->set_shapes(shapes);
}
bool HasNext() const { return reader_->HasNext(); }
private:
......
......@@ -19,11 +19,11 @@ namespace operators {
namespace reader {
template <typename T>
class RandomDataGenerator : public framework::FileReader {
class RandomDataGenerator : public framework::ReaderBase {
public:
RandomDataGenerator(const std::vector<framework::DDim>& shapes, float min,
float max)
: FileReader(shapes), min_(min), max_(max) {
: framework::ReaderBase(), min_(min), max_(max), shapes_(shapes) {
PADDLE_ENFORCE_LE(
min, max, "'min' shouldn't be greater than 'max'.(%f vs %f)", min, max);
unsigned int seed = std::random_device()();
......@@ -59,6 +59,7 @@ class RandomDataGenerator : public framework::FileReader {
float max_;
std::minstd_rand engine_;
std::uniform_real_distribution<float> dist_;
std::vector<framework::DDim> shapes_;
};
template <typename T>
......
......@@ -18,11 +18,10 @@
namespace paddle {
namespace operators {
namespace reader {
class RecordIOFileReader : public framework::FileReader {
class RecordIOFileReader : public framework::ReaderBase {
public:
RecordIOFileReader(const std::string& filename,
const std::vector<framework::DDim>& shapes)
: FileReader(shapes),
explicit RecordIOFileReader(const std::string& filename)
: ReaderBase(),
scanner_(filename),
dev_ctx_(*platform::DeviceContextPool::Instance().Get(
platform::CPUPlace())) {}
......@@ -54,12 +53,11 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
int(shape_concat.size()),
"The accumulate of all ranks should be equal to the "
"shape concat's length.");
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
std::string filename = Attr<std::string>("filename");
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new RecordIOFileReader(filename, shapes));
out->Reset(new RecordIOFileReader(filename));
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册