提交 d2c1fac1 编写于 作者: Y yuyang18

Merge branch 'dev_add_random_crop_op' of https://github.com/JiayiFeng/Paddle...

Merge branch 'dev_add_random_crop_op' of https://github.com/JiayiFeng/Paddle into dev_add_random_crop_op
......@@ -466,6 +466,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
protected:
DDim GetDim(const std::string& name) const override {
Variable* var = scope_.FindVar(name);
PADDLE_ENFORCE_NOT_NULL(var);
if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>().dims();
} else if (var->IsType<SelectedRows>()) {
......
......@@ -3996,13 +3996,19 @@ def random_crop(input, shape, seed=1):
out = helper.create_tmp_variable(dtype)
if isinstance(seed, int):
seed_value = seed
seed = helper.create_global_variable(
persistable=True, shape=[1], dtype="int32")
helper.set_variable_initializer(
var=seed, initializer=Constant(value=seed_value))
seed = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="fill_constant",
inputs={},
outputs={"Out": seed},
attrs={
"dtype": seed.dtype,
"shape": [1],
"value": float(seed_value)
})
elif not isinstance(seed, Variable):
raise ValueError("'seed' must be a Variable or an int.")
seed_out = helper.create_tmp_variable(dtype="int32")
seed_out = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="random_crop",
inputs={"X": input,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册