未验证 提交 669786bf 编写于 作者: Q QI JUN 提交者: GitHub

refine square_error_cost layer (#5216)

* reimplement pow operator

* add pow_grad operator

* fix code style

* fix build error

* fix op_test bug

* revert pow operator

* add FIXME comment
上级 afd1e844
......@@ -547,6 +547,7 @@ struct ELUGradFunctor : public BaseActivationFunctor<T> {
}
};
// FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198
template <typename T>
struct PowFunctor : public BaseActivationFunctor<T> {
float factor;
......
......@@ -225,10 +225,7 @@ def square_error_cost(input, label, **kwargs):
square_out = helper.create_tmp_variable(dtype=input.data_type)
helper.append_op(
type='pow',
inputs={'X': [minus_out]},
outputs={'Y': [square_out]},
attrs={'factor': 2.0})
type='square', inputs={'X': [minus_out]}, outputs={'Y': [square_out]})
return square_out
......
......@@ -281,7 +281,8 @@ class OpTest(unittest.TestCase):
type(sub_out))
for sub_out_name, expect in sub_out:
idx = find_actual(sub_out_name, fetch_list)
actual_t = np.array(outs[idx])
actual = outs[idx]
actual_t = np.array(actual)
expect_t = expect[0] \
if isinstance(expect, tuple) else expect
self.assertTrue(
......@@ -291,11 +292,12 @@ class OpTest(unittest.TestCase):
str(place))
if isinstance(expect, tuple):
self.assertListEqual(
actual_t.lod(), expect[1], "Output (" + sub_out_name
+ ") has different lod at " + str(place))
actual.lod(), expect[1], "Output (" + sub_out_name +
") has different lod at " + str(place))
else:
idx = find_actual(out_name, fetch_list)
actual_t = outs[idx]
actual = outs[idx]
actual_t = np.array(actual)
expect = self.outputs[out_name]
expect_t = expect[0] if isinstance(expect, tuple) else expect
self.assertTrue(
......@@ -303,7 +305,7 @@ class OpTest(unittest.TestCase):
actual_t, expect_t, atol=atol),
"Output (" + out_name + ") has diff at " + str(place))
if isinstance(expect, tuple):
self.assertListEqual(actual_t.lod(), expect[1],
self.assertListEqual(actual.lod(), expect[1],
"Output (" + out_name +
") has different lod at " + str(place))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册