提交 d2c1fac1 编写于 作者: Y yuyang18

Merge branch 'dev_add_random_crop_op' of https://github.com/JiayiFeng/Paddle...

Merge branch 'dev_add_random_crop_op' of https://github.com/JiayiFeng/Paddle into dev_add_random_crop_op
...@@ -466,6 +466,7 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -466,6 +466,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
protected: protected:
DDim GetDim(const std::string& name) const override { DDim GetDim(const std::string& name) const override {
Variable* var = scope_.FindVar(name); Variable* var = scope_.FindVar(name);
PADDLE_ENFORCE_NOT_NULL(var);
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>().dims(); return var->Get<LoDTensor>().dims();
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<SelectedRows>()) {
......
...@@ -3996,13 +3996,19 @@ def random_crop(input, shape, seed=1): ...@@ -3996,13 +3996,19 @@ def random_crop(input, shape, seed=1):
out = helper.create_tmp_variable(dtype) out = helper.create_tmp_variable(dtype)
if isinstance(seed, int): if isinstance(seed, int):
seed_value = seed seed_value = seed
seed = helper.create_global_variable( seed = helper.create_tmp_variable(dtype="int64")
persistable=True, shape=[1], dtype="int32") helper.append_op(
helper.set_variable_initializer( type="fill_constant",
var=seed, initializer=Constant(value=seed_value)) inputs={},
outputs={"Out": seed},
attrs={
"dtype": seed.dtype,
"shape": [1],
"value": float(seed_value)
})
elif not isinstance(seed, Variable): elif not isinstance(seed, Variable):
raise ValueError("'seed' must be a Variable or an int.") raise ValueError("'seed' must be a Variable or an int.")
seed_out = helper.create_tmp_variable(dtype="int32") seed_out = helper.create_tmp_variable(dtype="int64")
helper.append_op( helper.append_op(
type="random_crop", type="random_crop",
inputs={"X": input, inputs={"X": input,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册