Unverified commit aa887717, authored by xiaoguoguo626807, committed by GitHub

【prim】modify assign api setOutput in by_pass (#53417)

* modify concat_grad add sum comp rule

* cast and by_pass modify

* only modify by_pass

* modify by_pass
Parent 0ded3f04
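
In brief, as the hunks below show: by_pass<DescTensor> builds an assign op in the current BlockDesc to forward a tensor unchanged. Before this patch it pre-allocated a dummy FLOAT32 scalar (new_out), wired the assign op's "Out" directly to the caller's output tensor, and then overwrote that output with the dummy via set_output, so the op's output var and the tensor the caller ended up holding were not the same. After the patch, the output parameter is renamed real_out, the assign op's "Out" is bound to a fresh var created with the input's dtype, and only after CheckAttrs/InferVarType/InferShape is that var handed back through set_output.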
...
@@ -116,6 +116,7 @@ void cast_grad(const Tensor& out_grad, DataType dtype, Tensor* x_grad) {
     set_output<T>(res, x_grad);
   }
 }
+
 template <typename T>
 void gather_grad(const Tensor& x,
                  const Tensor& index,
...
...
@@ -53,20 +53,20 @@ void set_output<DescTensor>(const paddle::Tensor& x_tmp, paddle::Tensor* x) {
 }
 
 template <>
-void by_pass<DescTensor>(const paddle::Tensor& x, paddle::Tensor* out) {
-  Tensor new_out =
-      empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
+void by_pass<DescTensor>(const paddle::Tensor& x, paddle::Tensor* real_out) {
   framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
   framework::OpDesc* op = block->AppendOp();
   op->SetType("assign");
   op->SetInput("X",
                {std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
+  auto out = empty<DescTensor>({}, x.dtype(), paddle::Place());
   op->SetOutput(
-      "Out", {std::static_pointer_cast<prim::DescTensor>(out->impl())->Name()});
+      "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
   op->CheckAttrs();
   op->InferVarType(block);
   op->InferShape(*block);
-  set_output<DescTensor>(new_out, out);
+
+  set_output<DescTensor>(out, real_out);
 }
 
 }  // namespace prim
...
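
For reference, here is the by_pass<DescTensor> specialization as it reads after this hunk, reassembled in one piece with explanatory comments added (the enclosing namespace and headers are unchanged by the patch):

template <>
void by_pass<DescTensor>(const paddle::Tensor& x, paddle::Tensor* real_out) {
  // Append an assign op to the block currently being composed.
  framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
  framework::OpDesc* op = block->AppendOp();
  op->SetType("assign");
  op->SetInput("X",
               {std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
  // Create the op's output var with the input's dtype; the old code built a
  // FLOAT32 dummy up front and pointed the op at the caller's tensor instead.
  auto out = empty<DescTensor>({}, x.dtype(), paddle::Place());
  op->SetOutput(
      "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
  op->CheckAttrs();
  op->InferVarType(block);
  op->InferShape(*block);

  // Only now hand the fully wired var back to the caller.
  set_output<DescTensor>(out, real_out);
}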
...
@@ -462,6 +462,7 @@ class TestSoftmaxAPI(unittest.TestCase):
         self.softmax = F.softmax
 
     def test_static_check(self):
+        with paddle.fluid.framework._static_guard():
             with paddle.static.program_guard(paddle.static.Program()):
                 x = paddle.static.data('X', self.x_np.shape, 'float32')
                 out1 = self.softmax(x)
...
@@ -505,6 +506,7 @@ class TestSoftmaxAPI(unittest.TestCase):
         paddle.enable_static()
 
     def test_error(self):
+        with paddle.fluid.framework._static_guard():
             with paddle.static.program_guard(paddle.static.Program()):
                 # The input type must be Variable.
                 self.assertRaises(TypeError, self.softmax, 1)
...
@@ -538,6 +540,7 @@ class TestSoftmaxAPI_ZeroDim(unittest.TestCase):
         paddle.enable_static()
 
     def test_static(self):
+        with paddle.fluid.framework._static_guard():
             main_prog = fluid.Program()
             with fluid.program_guard(main_prog, fluid.Program()):
                 x = paddle.rand([])
...
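
The test-side change is mechanical: the static-graph bodies of test_static_check, test_error, and test_static are each wrapped in paddle.fluid.framework._static_guard(), an internal helper that, by its name and usage in the test suite, runs the block in static-graph mode and restores the prior mode on exit, so the static-mode assumption of these tests does not leak into the rest of the suite.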