提交 89e40949 编写于 作者: Q qijun

fix type

上级 a64c88e6
...@@ -34,8 +34,6 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -34,8 +34,6 @@ class ElementwiseOp : public framework::OperatorWithKernel {
auto x_dim = ctx->GetInputDim("X"); auto x_dim = ctx->GetInputDim("X");
auto y_dim = ctx->GetInputDim("Y"); auto y_dim = ctx->GetInputDim("Y");
LOG(INFO) << x_dim;
LOG(INFO) << y_dim;
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(), PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
"Rank of first input must >= rank of second input.") "Rank of first input must >= rank of second input.")
ctx->SetOutputDim("Out", x_dim); ctx->SetOutputDim("Out", x_dim);
......
...@@ -48,7 +48,7 @@ exe = Executor(place) ...@@ -48,7 +48,7 @@ exe = Executor(place)
exe.run(init_program, feed={}, fetch_list=[]) exe.run(init_program, feed={}, fetch_list=[])
PASS_NUM = 10 PASS_NUM = 100
for pass_id in range(PASS_NUM): for pass_id in range(PASS_NUM):
for data in train_reader(): for data in train_reader():
x_data = np.array(map(lambda x: x[0], data)).astype("float32") x_data = np.array(map(lambda x: x[0], data)).astype("float32")
...@@ -67,4 +67,4 @@ for pass_id in range(PASS_NUM): ...@@ -67,4 +67,4 @@ for pass_id in range(PASS_NUM):
fetch_list=[avg_cost]) fetch_list=[avg_cost])
out = np.array(outs[0]) out = np.array(outs[0])
print out print out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册