Unverified · Commit aa887717, authored by xiaoguoguo626807, committed by GitHub

【prim】modify assign api setOutput in by_pass (#53417)

* modify concat_grad add sum comp rule

* cast and by_pass modify

* only modify by_pass

* modify by_pass
Parent: 0ded3f04
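
What the by_pass fix changes (summarizing the diff below): the old static-mode implementation allocated a placeholder tensor (`new_out`) hard-coded to `phi::DataType::FLOAT32`, registered the caller's output tensor as the `assign` op's `Out`, and finally re-pointed the caller's tensor at the placeholder via `set_output`, so the returned tensor neither came from the appended op nor carried the input's dtype. The new implementation creates the `assign` output itself with the input's dtype (`x.dtype()`), runs `CheckAttrs`/`InferVarType`/`InferShape` against that variable, and only then binds it to the caller's `real_out`.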
@@ -116,6 +116,7 @@ void cast_grad(const Tensor& out_grad, DataType dtype, Tensor* x_grad) {
     set_output<T>(res, x_grad);
   }
 }
+
 template <typename T>
 void gather_grad(const Tensor& x,
                  const Tensor& index,
@@ -53,20 +53,20 @@ void set_output<DescTensor>(const paddle::Tensor& x_tmp, paddle::Tensor* x) {
 }

 template <>
-void by_pass<DescTensor>(const paddle::Tensor& x, paddle::Tensor* out) {
-  Tensor new_out =
-      empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
+void by_pass<DescTensor>(const paddle::Tensor& x, paddle::Tensor* real_out) {
   framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
   framework::OpDesc* op = block->AppendOp();
   op->SetType("assign");
   op->SetInput("X",
               {std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
+  auto out = empty<DescTensor>({}, x.dtype(), paddle::Place());
   op->SetOutput(
-      "Out", {std::static_pointer_cast<prim::DescTensor>(out->impl())->Name()});
+      "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
   op->CheckAttrs();
   op->InferVarType(block);
   op->InferShape(*block);
-  set_output<DescTensor>(new_out, out);
+  set_output<DescTensor>(out, real_out);
 }

 }  // namespace prim
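The user-visible effect of the fix can be sketched from the Python side. `by_pass` lowers to a plain `assign` op, the same op that the public `paddle.assign` builds; after this change its output inherits the input's dtype instead of a float32 placeholder. A minimal sketch (illustrative only, not part of this PR):

```python
import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data('x', shape=[2, 3], dtype='float16')
    # paddle.assign appends an "assign" op, as by_pass does internally.
    y = paddle.assign(x)
    # The output should keep x's dtype (float16); the old by_pass
    # produced a FLOAT32 placeholder instead.
    assert y.dtype == x.dtype
```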
@@ -462,16 +462,17 @@ class TestSoftmaxAPI(unittest.TestCase):
         self.softmax = F.softmax

     def test_static_check(self):
-        with paddle.static.program_guard(paddle.static.Program()):
-            x = paddle.static.data('X', self.x_np.shape, 'float32')
-            out1 = self.softmax(x)
-            m = paddle.nn.Softmax()
-            out2 = m(x)
-            exe = paddle.static.Executor(self.place)
-            res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
-            out_ref = ref_softmax(self.x_np, axis=-1, dtype=None)
-            for r in res:
-                np.testing.assert_allclose(out_ref, r, rtol=1e-05)
+        with paddle.fluid.framework._static_guard():
+            with paddle.static.program_guard(paddle.static.Program()):
+                x = paddle.static.data('X', self.x_np.shape, 'float32')
+                out1 = self.softmax(x)
+                m = paddle.nn.Softmax()
+                out2 = m(x)
+                exe = paddle.static.Executor(self.place)
+                res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
+                out_ref = ref_softmax(self.x_np, axis=-1, dtype=None)
+                for r in res:
+                    np.testing.assert_allclose(out_ref, r, rtol=1e-05)

     def test_dygraph_check(self):
         paddle.disable_static(self.place)
@@ -505,19 +506,20 @@ class TestSoftmaxAPI(unittest.TestCase):
         paddle.enable_static()

     def test_error(self):
-        with paddle.static.program_guard(paddle.static.Program()):
-            # The input type must be Variable.
-            self.assertRaises(TypeError, self.softmax, 1)
-            # The input dtype must be float16, float32, float64.
-            x_int32 = paddle.static.data(
-                name='x_int32', shape=[2, 3], dtype='int32'
-            )
-            self.assertRaises(TypeError, self.softmax, x_int32)
-            # support the input dtype is float16
-            x_fp16 = paddle.static.data(
-                name='x_fp16', shape=[2, 3], dtype='float16'
-            )
-            self.softmax(x_fp16)
+        with paddle.fluid.framework._static_guard():
+            with paddle.static.program_guard(paddle.static.Program()):
+                # The input type must be Variable.
+                self.assertRaises(TypeError, self.softmax, 1)
+                # The input dtype must be float16, float32, float64.
+                x_int32 = paddle.static.data(
+                    name='x_int32', shape=[2, 3], dtype='int32'
+                )
+                self.assertRaises(TypeError, self.softmax, x_int32)
+                # support the input dtype is float16
+                x_fp16 = paddle.static.data(
+                    name='x_fp16', shape=[2, 3], dtype='float16'
+                )
+                self.softmax(x_fp16)


 class TestSoftmaxAPI_ZeroDim(unittest.TestCase):
@@ -538,23 +540,24 @@ class TestSoftmaxAPI_ZeroDim(unittest.TestCase):
         paddle.enable_static()

     def test_static(self):
-        main_prog = fluid.Program()
-        with fluid.program_guard(main_prog, fluid.Program()):
-            x = paddle.rand([])
-            x.stop_gradient = False
-            out = paddle.nn.functional.softmax(x)
-            fluid.backward.append_backward(out)
-
-            # Test compile shape
-            self.assertEqual(x.shape, ())
-            self.assertEqual(out.shape, ())
-
-            exe = fluid.Executor()
-            result = exe.run(main_prog, fetch_list=[x, out])
-
-            # Test runtime shape
-            self.assertEqual(result[0].shape, ())
-            self.assertEqual(result[1].shape, ())
+        with paddle.fluid.framework._static_guard():
+            main_prog = fluid.Program()
+            with fluid.program_guard(main_prog, fluid.Program()):
+                x = paddle.rand([])
+                x.stop_gradient = False
+                out = paddle.nn.functional.softmax(x)
+                fluid.backward.append_backward(out)
+
+                # Test compile shape
+                self.assertEqual(x.shape, ())
+                self.assertEqual(out.shape, ())
+
+                exe = fluid.Executor()
+                result = exe.run(main_prog, fetch_list=[x, out])
+
+                # Test runtime shape
+                self.assertEqual(result[0].shape, ())
+                self.assertEqual(result[1].shape, ())


 class TestSoftmaxInplaceAPI(TestSoftmaxAPI):
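All three test updates follow one mechanical pattern: the existing static-graph body gains one level of indentation under `paddle.fluid.framework._static_guard()`. Below is a minimal sketch of the pattern in isolation; the comment on `_static_guard` is an inference from its use in these tests, not from its definition:

```python
import numpy as np
import paddle
import paddle.nn.functional as F

# _static_guard appears to switch the process into static-graph mode for
# the duration of the block and restore the previous mode on exit.
with paddle.fluid.framework._static_guard():
    with paddle.static.program_guard(paddle.static.Program()):
        x = paddle.static.data('X', shape=[2, 3], dtype='float32')
        out = F.softmax(x)
        exe = paddle.static.Executor()
        (res,) = exe.run(
            feed={'X': np.random.rand(2, 3).astype('float32')},
            fetch_list=[out],
        )
```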