Commit 5c68765a authored by: Q qijun

fix persistable bug

Parent 0910a8db
......@@ -34,6 +34,8 @@ class ElementwiseOp : public framework::OperatorWithKernel {
     auto x_dim = ctx->GetInputDim("X");
     auto y_dim = ctx->GetInputDim("Y");
+    LOG(INFO) << x_dim;
+    LOG(INFO) << y_dim;
     PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
                       "Rank of first input must >= rank of second input.")
     ctx->SetOutputDim("Out", x_dim);
......
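Reviewer note: the two added LOG(INFO) lines are temporary debug output around the broadcast rank check, which requires rank(X) >= rank(Y). A minimal sketch of that rule in isolation (plain Python with illustrative names, not Paddle API):

```python
def infer_elementwise_out_dim(x_dim, y_dim):
    # Mirrors the PADDLE_ENFORCE_GE check above: Y broadcasts over X,
    # so X's rank must be at least Y's rank, and Out takes X's shape.
    if len(x_dim) < len(y_dim):
        raise ValueError("Rank of first input must >= rank of second input.")
    return x_dim

# e.g. adding a per-channel bias of shape (64,) to a (32, 64, 28, 28) input:
print(infer_elementwise_out_dim((32, 64, 28, 28), (64,)))  # (32, 64, 28, 28)
```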
......@@ -15,7 +15,7 @@ class Variable(object):
                  shape=None,
                  dtype=None,
                  lod_level=None,
-                 persistable=False,
+                 persistable=None,
                  **kwargs):
         self.block = block
......
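Reviewer note: switching the default from False to None plausibly turns persistable into a tri-state flag, so a caller that omits it can be told apart from one that explicitly passes False. This is a guess at the bug being fixed, sketched with a stand-in class (not the real Variable):

```python
class Var(object):
    """Stand-in, not Paddle's Variable: shows why None beats False
    as the default for a flag that may already be set elsewhere."""

    registry = {}  # previously created variables, by name

    def __init__(self, name, persistable=None):
        prev = Var.registry.get(name)
        if persistable is None:
            # Caller did not specify: keep any existing setting,
            # falling back to False for brand-new variables.
            persistable = prev.persistable if prev else False
        self.name = name
        self.persistable = persistable
        Var.registry[name] = self

w = Var("fc.w", persistable=True)  # a parameter that must survive runs
w = Var("fc.w")                    # fetched again without the flag
assert w.persistable               # with a False default this would reset
```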
......@@ -121,10 +121,13 @@ class LayerHelper(object):
     def create_tmp_variable(self, dtype):
         return self.program.current_block().create_var(
-            name=unique_name(".".join([self.name, 'tmp'])), dtype=dtype)
+            name=unique_name(".".join([self.name, 'tmp'])),
+            persistable=False,
+            dtype=dtype)
 
     def create_global_variable(self, *args, **kwargs):
-        return self.program.global_block().create_var(*args, **kwargs)
+        return self.program.global_block().create_var(
+            *args, persistable=False, **kwargs)
 
     def append_bias_op(self, input_var):
         size = list(input_var.shape[1:])
......
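Reviewer note: with the Variable default now None, the helpers state their intent explicitly: variables they create are persistable=False, so the runtime may free them between iterations. A self-contained sketch of that behavior (stand-in classes, not the real program/block API):

```python
import itertools

_uid = itertools.count()

def unique_name(prefix):
    # Stand-in for Paddle's unique_name helper.
    return "%s.%d" % (prefix, next(_uid))

class Block(object):
    """Minimal stand-in that records created variables."""
    def __init__(self):
        self.vars = {}

    def create_var(self, name, dtype=None, persistable=None):
        self.vars[name] = {"dtype": dtype, "persistable": persistable}
        return name

class Helper(object):
    def __init__(self, block, name):
        self.block = block
        self.name = name

    def create_tmp_variable(self, dtype):
        # As in the diff: temporaries are explicitly non-persistable.
        return self.block.create_var(
            name=unique_name(".".join([self.name, "tmp"])),
            persistable=False,
            dtype=dtype)

block = Block()
tmp = Helper(block, "fc").create_tmp_variable("float32")
print(tmp, block.vars[tmp])  # fc.tmp.0 {'dtype': 'float32', 'persistable': False}
```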
......@@ -37,13 +37,18 @@ for pass_id in range(PASS_NUM):
     for data in train_reader():
         x_data = np.array(map(lambda x: x[0], data)).astype("float32")
         y_data = np.array(map(lambda x: x[1], data)).astype("float32")
-        y_data = np.expand_dims(y_data, axis=1)
+        # y_data = np.expand_dims(y_data, axis=1)
+        # print x_data
+        # print type(x_data)
+        # print y_data
         tensor_x = core.LoDTensor()
         tensor_x.set(x_data, place)
+        # print tensor_x.get_dims()
         tensor_y = core.LoDTensor()
         tensor_y.set(y_data, place)
+        # print tensor_y.get_dims()
         outs = exe.run(program,
                        feed={'x': tensor_x,
                              'y': tensor_y},
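Reviewer note: with np.expand_dims commented out, y_data is now fed as a rank-1 array rather than a column vector. The NumPy-only difference, independent of the LoDTensor plumbing above:

```python
import numpy as np

y_data = np.array([1.0, 2.0, 3.0], dtype="float32")
print(y_data.shape)                          # (3,)   rank-1, as now fed
print(np.expand_dims(y_data, axis=1).shape)  # (3, 1) the previous column shape
```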