未验证 提交 aa887717 编写于 作者: X xiaoguoguo626807 提交者: GitHub

【prim】modify assign api setOutput in by_pass (#53417)

* modify concat_grad add sum comp rule

* cast and by_pass modify

* only modify by_pass

* modify by_pass
上级 0ded3f04
...@@ -116,6 +116,7 @@ void cast_grad(const Tensor& out_grad, DataType dtype, Tensor* x_grad) { ...@@ -116,6 +116,7 @@ void cast_grad(const Tensor& out_grad, DataType dtype, Tensor* x_grad) {
set_output<T>(res, x_grad); set_output<T>(res, x_grad);
} }
} }
template <typename T> template <typename T>
void gather_grad(const Tensor& x, void gather_grad(const Tensor& x,
const Tensor& index, const Tensor& index,
......
...@@ -53,20 +53,20 @@ void set_output<DescTensor>(const paddle::Tensor& x_tmp, paddle::Tensor* x) { ...@@ -53,20 +53,20 @@ void set_output<DescTensor>(const paddle::Tensor& x_tmp, paddle::Tensor* x) {
} }
template <> template <>
void by_pass<DescTensor>(const paddle::Tensor& x, paddle::Tensor* out) { void by_pass<DescTensor>(const paddle::Tensor& x, paddle::Tensor* real_out) {
Tensor new_out =
empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock(); framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
framework::OpDesc* op = block->AppendOp(); framework::OpDesc* op = block->AppendOp();
op->SetType("assign"); op->SetType("assign");
op->SetInput("X", op->SetInput("X",
{std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()}); {std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
auto out = empty<DescTensor>({}, x.dtype(), paddle::Place());
op->SetOutput( op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out->impl())->Name()}); "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
op->CheckAttrs(); op->CheckAttrs();
op->InferVarType(block); op->InferVarType(block);
op->InferShape(*block); op->InferShape(*block);
set_output<DescTensor>(new_out, out);
set_output<DescTensor>(out, real_out);
} }
} // namespace prim } // namespace prim
......
...@@ -462,16 +462,17 @@ class TestSoftmaxAPI(unittest.TestCase): ...@@ -462,16 +462,17 @@ class TestSoftmaxAPI(unittest.TestCase):
self.softmax = F.softmax self.softmax = F.softmax
def test_static_check(self): def test_static_check(self):
with paddle.static.program_guard(paddle.static.Program()): with paddle.fluid.framework._static_guard():
x = paddle.static.data('X', self.x_np.shape, 'float32') with paddle.static.program_guard(paddle.static.Program()):
out1 = self.softmax(x) x = paddle.static.data('X', self.x_np.shape, 'float32')
m = paddle.nn.Softmax() out1 = self.softmax(x)
out2 = m(x) m = paddle.nn.Softmax()
exe = paddle.static.Executor(self.place) out2 = m(x)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2]) exe = paddle.static.Executor(self.place)
out_ref = ref_softmax(self.x_np, axis=-1, dtype=None) res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
for r in res: out_ref = ref_softmax(self.x_np, axis=-1, dtype=None)
np.testing.assert_allclose(out_ref, r, rtol=1e-05) for r in res:
np.testing.assert_allclose(out_ref, r, rtol=1e-05)
def test_dygraph_check(self): def test_dygraph_check(self):
paddle.disable_static(self.place) paddle.disable_static(self.place)
...@@ -505,19 +506,20 @@ class TestSoftmaxAPI(unittest.TestCase): ...@@ -505,19 +506,20 @@ class TestSoftmaxAPI(unittest.TestCase):
paddle.enable_static() paddle.enable_static()
def test_error(self): def test_error(self):
with paddle.static.program_guard(paddle.static.Program()): with paddle.fluid.framework._static_guard():
# The input type must be Variable. with paddle.static.program_guard(paddle.static.Program()):
self.assertRaises(TypeError, self.softmax, 1) # The input type must be Variable.
# The input dtype must be float16, float32, float64. self.assertRaises(TypeError, self.softmax, 1)
x_int32 = paddle.static.data( # The input dtype must be float16, float32, float64.
name='x_int32', shape=[2, 3], dtype='int32' x_int32 = paddle.static.data(
) name='x_int32', shape=[2, 3], dtype='int32'
self.assertRaises(TypeError, self.softmax, x_int32) )
# support the input dtype is float16 self.assertRaises(TypeError, self.softmax, x_int32)
x_fp16 = paddle.static.data( # support the input dtype is float16
name='x_fp16', shape=[2, 3], dtype='float16' x_fp16 = paddle.static.data(
) name='x_fp16', shape=[2, 3], dtype='float16'
self.softmax(x_fp16) )
self.softmax(x_fp16)
class TestSoftmaxAPI_ZeroDim(unittest.TestCase): class TestSoftmaxAPI_ZeroDim(unittest.TestCase):
...@@ -538,23 +540,24 @@ class TestSoftmaxAPI_ZeroDim(unittest.TestCase): ...@@ -538,23 +540,24 @@ class TestSoftmaxAPI_ZeroDim(unittest.TestCase):
paddle.enable_static() paddle.enable_static()
def test_static(self): def test_static(self):
main_prog = fluid.Program() with paddle.fluid.framework._static_guard():
with fluid.program_guard(main_prog, fluid.Program()): main_prog = fluid.Program()
x = paddle.rand([]) with fluid.program_guard(main_prog, fluid.Program()):
x.stop_gradient = False x = paddle.rand([])
out = paddle.nn.functional.softmax(x) x.stop_gradient = False
fluid.backward.append_backward(out) out = paddle.nn.functional.softmax(x)
fluid.backward.append_backward(out)
# Test compile shape
self.assertEqual(x.shape, ()) # Test compile shape
self.assertEqual(out.shape, ()) self.assertEqual(x.shape, ())
self.assertEqual(out.shape, ())
exe = fluid.Executor()
result = exe.run(main_prog, fetch_list=[x, out]) exe = fluid.Executor()
result = exe.run(main_prog, fetch_list=[x, out])
# Test runtime shape
self.assertEqual(result[0].shape, ()) # Test runtime shape
self.assertEqual(result[1].shape, ()) self.assertEqual(result[0].shape, ())
self.assertEqual(result[1].shape, ())
class TestSoftmaxInplaceAPI(TestSoftmaxAPI): class TestSoftmaxInplaceAPI(TestSoftmaxAPI):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册