未验证 提交 aa887717 编写于 作者: X xiaoguoguo626807 提交者: GitHub

【prim】modify assign api setOutput in by_pass (#53417)

* modify concat_grad add sum comp rule

* cast and by_pass modify

* only modify by_pass

* modify by_pass
上级 0ded3f04
......@@ -116,6 +116,7 @@ void cast_grad(const Tensor& out_grad, DataType dtype, Tensor* x_grad) {
set_output<T>(res, x_grad);
}
}
template <typename T>
void gather_grad(const Tensor& x,
const Tensor& index,
......
......@@ -53,20 +53,20 @@ void set_output<DescTensor>(const paddle::Tensor& x_tmp, paddle::Tensor* x) {
}
template <>
void by_pass<DescTensor>(const paddle::Tensor& x, paddle::Tensor* out) {
Tensor new_out =
empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
void by_pass<DescTensor>(const paddle::Tensor& x, paddle::Tensor* real_out) {
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
framework::OpDesc* op = block->AppendOp();
op->SetType("assign");
op->SetInput("X",
{std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
auto out = empty<DescTensor>({}, x.dtype(), paddle::Place());
op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out->impl())->Name()});
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
op->CheckAttrs();
op->InferVarType(block);
op->InferShape(*block);
set_output<DescTensor>(new_out, out);
set_output<DescTensor>(out, real_out);
}
} // namespace prim
......
......@@ -462,6 +462,7 @@ class TestSoftmaxAPI(unittest.TestCase):
self.softmax = F.softmax
def test_static_check(self):
with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', self.x_np.shape, 'float32')
out1 = self.softmax(x)
......@@ -505,6 +506,7 @@ class TestSoftmaxAPI(unittest.TestCase):
paddle.enable_static()
def test_error(self):
with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
self.assertRaises(TypeError, self.softmax, 1)
......@@ -538,6 +540,7 @@ class TestSoftmaxAPI_ZeroDim(unittest.TestCase):
paddle.enable_static()
def test_static(self):
with paddle.fluid.framework._static_guard():
main_prog = fluid.Program()
with fluid.program_guard(main_prog, fluid.Program()):
x = paddle.rand([])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册