未验证 提交 32d50864 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #11042 from JiayiFeng/fix_two_bugs

fix two bugs
...@@ -23,13 +23,12 @@ namespace reader { ...@@ -23,13 +23,12 @@ namespace reader {
class CustomReader : public framework::DecoratedReader { class CustomReader : public framework::DecoratedReader {
public: public:
CustomReader(ReaderBase* reader, const framework::BlockDesc& sub_block, CustomReader(ReaderBase* reader, const framework::BlockDesc& sub_block,
const platform::Place& dev_place,
const std::vector<std::string>& source_var_names, const std::vector<std::string>& source_var_names,
const std::vector<std::string>& sink_var_names) const std::vector<std::string>& sink_var_names)
: DecoratedReader(reader), : DecoratedReader(reader),
program_(*sub_block.Program()), program_(*sub_block.Program()),
sub_block_id_(sub_block.ID()), sub_block_id_(sub_block.ID()),
exe_(framework::Executor(dev_place)), exe_(framework::Executor(platform::CPUPlace())),
source_var_names_(source_var_names), source_var_names_(source_var_names),
sink_var_names_(sink_var_names) {} sink_var_names_(sink_var_names) {}
...@@ -60,7 +59,7 @@ class CreateCustomReaderOp : public framework::OperatorBase { ...@@ -60,7 +59,7 @@ class CreateCustomReaderOp : public framework::OperatorBase {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
out->Reset( out->Reset(
new CustomReader(underlying_reader.Get(), *sub_block, dev_place, new CustomReader(underlying_reader.Get(), *sub_block,
Attr<std::vector<std::string>>("source_var_names"), Attr<std::vector<std::string>>("source_var_names"),
Attr<std::vector<std::string>>("sink_var_names"))); Attr<std::vector<std::string>>("sink_var_names")));
} }
...@@ -85,9 +84,10 @@ class CreateCustomReaderOpMaker : public DecoratedReaderMakerBase { ...@@ -85,9 +84,10 @@ class CreateCustomReaderOpMaker : public DecoratedReaderMakerBase {
CreateCustomReader Operator CreateCustomReader Operator
A custom reader can be used for input data preprocessing. A custom reader can be used for input data preprocessing.
A custom reader holds its own sub-block, which will be executed in its A custom reader holds its own sub-block, which will be executed in CPU
'ReadNext()' function. Users can configurate their own preprocessing in its 'ReadNext()' function. Users can configurate their own
pipelines by inserting operators into custom reader's sub-block. preprocessing pipelines by inserting operators into custom reader's
sub-block.
)DOC"); )DOC");
} }
}; };
......
...@@ -4009,7 +4009,8 @@ def random_crop(input, shape, seed=1): ...@@ -4009,7 +4009,8 @@ def random_crop(input, shape, seed=1):
attrs={ attrs={
"dtype": seed.dtype, "dtype": seed.dtype,
"shape": [1], "shape": [1],
"value": float(seed_value) "value": float(seed_value),
"force_cpu": True
}) })
elif not isinstance(seed, Variable): elif not isinstance(seed, Variable):
raise ValueError("'seed' must be a Variable or an int.") raise ValueError("'seed' must be a Variable or an int.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册