未验证 提交 64f18fa1 编写于 作者: C Charles-hit 提交者: GitHub

fix sum vjp and prim code gen cmake (#56542)

上级 f166ecc3
......@@ -78,7 +78,7 @@ std::vector<std::vector<ir::OpResult>> SumOp::Vjp(
bool reduce_all = false;
std::vector<std::vector<Tensor>> tensor_res = primitive::sum_vjp(
x, out_grad, axis, keepdim, reduce_all, stop_gradients);
std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1));
std::vector<std::vector<ir::OpResult>> res(2, std::vector<ir::OpResult>(1));
if (tensor_res[0][0].defined()) {
res[0][0] =
std::static_pointer_cast<primitive::LazyTensor>(tensor_res[0][0].impl())
......
set(eager_backend_files
${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/backend/generated/generated_eager_backend.cc
)
${CMAKE_CURRENT_SOURCE_DIR}/generated/generated_eager_backend.cc)
if(WITH_PYTHON OR NOT ON_INFER)
cc_library(
primitive_backend_eager_experimental
......@@ -8,9 +7,8 @@ if(WITH_PYTHON OR NOT ON_INFER)
DEPS final_dygraph_function eager_utils phi)
endif()
set(static_backend_files
${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/backend/generated/generated_static_backend.cc
${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/backend/manual/manual_static_backend.cc
)
${CMAKE_CURRENT_SOURCE_DIR}/generated/generated_static_backend.cc
${CMAKE_CURRENT_SOURCE_DIR}/manual/manual_static_backend.cc)
cc_library(
primitive_backend_static_experimental
SRCS ${static_backend_files}
......
set(fwd_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/ops.parsed.yaml"
)
set(fwd_legacy_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_ops.parsed.yaml"
)
set(rev_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/backward_ops.parsed.yaml"
)
set(rev_legacy_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_backward_ops.parsed.yaml"
)
set(parsed_yaml_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops")
set(fwd_path ${parsed_yaml_path}/ops.parsed.yaml)
set(fwd_legacy_path ${parsed_yaml_path}/legacy_ops.parsed.yaml)
set(rev_path ${parsed_yaml_path}/backward_ops.parsed.yaml)
set(rev_legacy_path ${parsed_yaml_path}/legacy_backward_ops.parsed.yaml)
set(prim_path "${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/primitive.yaml")
set(templates_dir
"${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/codegen/templates/")
......
......@@ -126,14 +126,15 @@ class TestVjpPrim(unittest.TestCase):
paddle.fluid.core._set_prim_backward_enabled(True)
dout = newir_program.block().ops[-2].result(0)
out_grads = [[dout]]
stop_gradients = [[False]]
stop_gradients = [[False], [True]]
sum_op = newir_program.block().ops[-1]
with paddle.ir.core.program_guard(newir_program):
grad_outs = call_vjp(sum_op, out_grads, stop_gradients)
expand_op = newir_program.block().ops[-1]
self.assertEqual(len(grad_outs), 1)
self.assertEqual(len(grad_outs), 2)
self.assertEqual(len(newir_program.block().ops), 8)
self.assertEqual(expand_op.result(0), grad_outs[0][0])
self.assertEqual(grad_outs[1][0], None)
all_op_names = [
"pd.full",
"pd.full",
......@@ -152,14 +153,15 @@ class TestVjpPrim(unittest.TestCase):
paddle.fluid.core._set_prim_backward_enabled(False)
dout = newir_program.block().ops[-2].result(0)
out_grads = [[dout]]
stop_gradients = [[False]]
stop_gradients = [[False], [True]]
sum_op = newir_program.block().ops[-1]
with paddle.ir.core.program_guard(newir_program):
grad_outs = call_vjp(sum_op, out_grads, stop_gradients)
self.assertEqual(len(grad_outs), 1)
self.assertEqual(len(grad_outs), 2)
self.assertEqual(
grad_outs[0][0].get_defining_op().name(), "pd.sum_grad"
)
self.assertEqual(grad_outs[1][0], None)
self.assertEqual(len(newir_program.block().ops), 6)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册