提交 225efa67 编写于 作者: Y Yu Yang

Remove dims in base class

上级 2ea4a5d9
...@@ -442,15 +442,7 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -442,15 +442,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
} }
std::vector<DDim> GetRepeatedDims(const std::string& name) const override { std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
Variable* var = scope_.FindVar(name); PADDLE_THROW("Only compile time support this method");
if (var->IsType<ReaderHolder>()) {
return var->Get<ReaderHolder>().shapes();
} else {
PADDLE_THROW(
"Only ReaderHolder support 'GetRepeatedDims', but Variable %s's "
"type_id is %s.",
name, var->Type().name());
}
} }
void SetDim(const std::string& name, const DDim& dim) override { void SetDim(const std::string& name, const DDim& dim) override {
...@@ -467,15 +459,7 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -467,15 +459,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
void SetRepeatedDims(const std::string& name, void SetRepeatedDims(const std::string& name,
const std::vector<DDim>& dims) override { const std::vector<DDim>& dims) override {
Variable* var = scope_.FindVar(name); PADDLE_THROW("Only compile time support this method");
if (var->IsType<ReaderHolder>()) {
var->GetMutable<ReaderHolder>()->set_shapes(dims);
} else {
PADDLE_THROW(
"Only ReaderHolder support 'SetRepeatedDims', but Variable %s's "
"type_id is %s.",
name, var->Type().name());
}
} }
proto::VarType::Type GetVarType(const std::string& name) const override { proto::VarType::Type GetVarType(const std::string& name) const override {
......
...@@ -16,14 +16,6 @@ ...@@ -16,14 +16,6 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
ReaderBase::~ReaderBase() {}
DDim ReaderBase::shape(size_t idx) const {
PADDLE_ENFORCE_LT(
idx, shapes_.size(),
"Cannot get the %d'th shape, 'shapes_' only has %d elements.", idx,
shapes_.size());
return shapes_[idx];
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -22,34 +22,18 @@ namespace framework { ...@@ -22,34 +22,18 @@ namespace framework {
class ReaderBase { class ReaderBase {
public: public:
explicit ReaderBase(const std::vector<DDim>& shapes) : shapes_(shapes) {
PADDLE_ENFORCE(!shapes_.empty());
}
virtual void ReadNext(std::vector<LoDTensor>* out) = 0; virtual void ReadNext(std::vector<LoDTensor>* out) = 0;
virtual void ReInit() = 0; virtual void ReInit() = 0;
DDim shape(size_t idx) const;
std::vector<DDim> shapes() const { return shapes_; }
void set_shapes(const std::vector<DDim>& shapes) { shapes_ = shapes; }
virtual bool HasNext() const = 0; virtual bool HasNext() const = 0;
virtual ~ReaderBase() {} virtual ~ReaderBase();
protected:
std::vector<DDim> shapes_;
};
class FileReader : public ReaderBase {
public:
explicit FileReader(const std::vector<DDim>& shapes) : ReaderBase(shapes) {}
}; };
class DecoratedReader : public ReaderBase { class DecoratedReader : public ReaderBase {
public: public:
explicit DecoratedReader(ReaderBase* reader) explicit DecoratedReader(ReaderBase* reader) : ReaderBase(), reader_(reader) {
: ReaderBase(reader->shapes()), reader_(reader) {
PADDLE_ENFORCE_NOT_NULL(reader_); PADDLE_ENFORCE_NOT_NULL(reader_);
} }
...@@ -72,12 +56,6 @@ class ReaderHolder { ...@@ -72,12 +56,6 @@ class ReaderHolder {
void ReadNext(std::vector<LoDTensor>* out) { reader_->ReadNext(out); } void ReadNext(std::vector<LoDTensor>* out) { reader_->ReadNext(out); }
void ReInit() { reader_->ReInit(); } void ReInit() { reader_->ReInit(); }
DDim shape(size_t idx) const { return reader_->shape(idx); }
std::vector<DDim> shapes() const { return reader_->shapes(); }
void set_shapes(const std::vector<DDim>& shapes) {
reader_->set_shapes(shapes);
}
bool HasNext() const { return reader_->HasNext(); } bool HasNext() const { return reader_->HasNext(); }
private: private:
......
...@@ -19,11 +19,11 @@ namespace operators { ...@@ -19,11 +19,11 @@ namespace operators {
namespace reader { namespace reader {
template <typename T> template <typename T>
class RandomDataGenerator : public framework::FileReader { class RandomDataGenerator : public framework::ReaderBase {
public: public:
RandomDataGenerator(const std::vector<framework::DDim>& shapes, float min, RandomDataGenerator(const std::vector<framework::DDim>& shapes, float min,
float max) float max)
: FileReader(shapes), min_(min), max_(max) { : framework::ReaderBase(), min_(min), max_(max), shapes_(shapes) {
PADDLE_ENFORCE_LE( PADDLE_ENFORCE_LE(
min, max, "'min' shouldn't be greater than 'max'.(%f vs %f)", min, max); min, max, "'min' shouldn't be greater than 'max'.(%f vs %f)", min, max);
unsigned int seed = std::random_device()(); unsigned int seed = std::random_device()();
...@@ -59,6 +59,7 @@ class RandomDataGenerator : public framework::FileReader { ...@@ -59,6 +59,7 @@ class RandomDataGenerator : public framework::FileReader {
float max_; float max_;
std::minstd_rand engine_; std::minstd_rand engine_;
std::uniform_real_distribution<float> dist_; std::uniform_real_distribution<float> dist_;
std::vector<framework::DDim> shapes_;
}; };
template <typename T> template <typename T>
......
...@@ -18,11 +18,10 @@ ...@@ -18,11 +18,10 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace reader { namespace reader {
class RecordIOFileReader : public framework::FileReader { class RecordIOFileReader : public framework::ReaderBase {
public: public:
RecordIOFileReader(const std::string& filename, explicit RecordIOFileReader(const std::string& filename)
const std::vector<framework::DDim>& shapes) : ReaderBase(),
: FileReader(shapes),
scanner_(filename), scanner_(filename),
dev_ctx_(*platform::DeviceContextPool::Instance().Get( dev_ctx_(*platform::DeviceContextPool::Instance().Get(
platform::CPUPlace())) {} platform::CPUPlace())) {}
...@@ -54,12 +53,11 @@ class CreateRecordIOReaderOp : public framework::OperatorBase { ...@@ -54,12 +53,11 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
int(shape_concat.size()), int(shape_concat.size()),
"The accumulate of all ranks should be equal to the " "The accumulate of all ranks should be equal to the "
"shape concat's length."); "shape concat's length.");
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
std::string filename = Attr<std::string>("filename"); std::string filename = Attr<std::string>("filename");
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
out->Reset(new RecordIOFileReader(filename, shapes)); out->Reset(new RecordIOFileReader(filename));
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册