Unverified commit 64f18fa1, authored by Charles-hit, committed by GitHub

fix sum vjp and prim code gen cmake (#56542)

Parent: f166ecc3
...
@@ -78,7 +78,7 @@ std::vector<std::vector<ir::OpResult>> SumOp::Vjp(
   bool reduce_all = false;
   std::vector<std::vector<Tensor>> tensor_res = primitive::sum_vjp(
       x, out_grad, axis, keepdim, reduce_all, stop_gradients);
-  std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1));
+  std::vector<std::vector<ir::OpResult>> res(2, std::vector<ir::OpResult>(1));
   if (tensor_res[0][0].defined()) {
     res[0][0] =
         std::static_pointer_cast<primitive::LazyTensor>(tensor_res[0][0].impl())
...
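For readers of the hunk above: in the new IR, an op's Vjp returns one outer slot per op input. sum takes both x and axis, so the result needs two slots, and the slot for an input whose gradient is stopped stays undefined. A minimal Python sketch of that convention (make_vjp_result is a hypothetical helper, not Paddle API):

def make_vjp_result(num_inputs, computed_grads, stop_gradients):
    # One outer slot per op input; None marks a stopped gradient.
    res = [[None] for _ in range(num_inputs)]
    for i, grad in computed_grads.items():
        if not stop_gradients[i][0]:
            res[i][0] = grad
    return res

# sum(x, axis): the gradient flows to x (slot 0) but not to axis (slot 1),
# mirroring res(2, ...) above and grad_outs[1][0] == None in the tests below.
assert make_vjp_result(2, {0: "dx"}, [[False], [True]]) == [["dx"], [None]]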
 set(eager_backend_files
-    ${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/backend/generated/generated_eager_backend.cc
-)
+    ${CMAKE_CURRENT_SOURCE_DIR}/generated/generated_eager_backend.cc)
 if(WITH_PYTHON OR NOT ON_INFER)
   cc_library(
     primitive_backend_eager_experimental
@@ -8,9 +7,8 @@ if(WITH_PYTHON OR NOT ON_INFER)
     DEPS final_dygraph_function eager_utils phi)
 endif()
 set(static_backend_files
-    ${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/backend/generated/generated_static_backend.cc
-    ${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/backend/manual/manual_static_backend.cc
-)
+    ${CMAKE_CURRENT_SOURCE_DIR}/generated/generated_static_backend.cc
+    ${CMAKE_CURRENT_SOURCE_DIR}/manual/manual_static_backend.cc)
 cc_library(
   primitive_backend_static_experimental
   SRCS ${static_backend_files}
...
-set(fwd_path
-    "${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/ops.parsed.yaml"
-)
-set(fwd_legacy_path
-    "${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_ops.parsed.yaml"
-)
-set(rev_path
-    "${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/backward_ops.parsed.yaml"
-)
-set(rev_legacy_path
-    "${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_backward_ops.parsed.yaml"
-)
+set(parsed_yaml_path
+    "${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops")
+set(fwd_path ${parsed_yaml_path}/ops.parsed.yaml)
+set(fwd_legacy_path ${parsed_yaml_path}/legacy_ops.parsed.yaml)
+set(rev_path ${parsed_yaml_path}/backward_ops.parsed.yaml)
+set(rev_legacy_path ${parsed_yaml_path}/legacy_backward_ops.parsed.yaml)
 set(prim_path "${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/primitive.yaml")
 set(templates_dir
     "${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/codegen/templates/")
...
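The refactor above only factors the shared parsed_ops base directory into parsed_yaml_path; the four YAML file names are unchanged. A hedged sketch of consuming them, with paths mirroring the CMake variables (the loader itself is an assumption, not Paddle's actual generator):

from pathlib import Path
import yaml  # assumed available

parsed_yaml_path = Path("paddle/fluid/operators/generator/parsed_ops")
specs = {
    name: yaml.safe_load((parsed_yaml_path / fname).read_text())
    for name, fname in {
        "fwd": "ops.parsed.yaml",
        "fwd_legacy": "legacy_ops.parsed.yaml",
        "rev": "backward_ops.parsed.yaml",
        "rev_legacy": "legacy_backward_ops.parsed.yaml",
    }.items()
}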
@@ -126,14 +126,15 @@ class TestVjpPrim(unittest.TestCase):
         paddle.fluid.core._set_prim_backward_enabled(True)
         dout = newir_program.block().ops[-2].result(0)
         out_grads = [[dout]]
-        stop_gradients = [[False]]
+        stop_gradients = [[False], [True]]
         sum_op = newir_program.block().ops[-1]
         with paddle.ir.core.program_guard(newir_program):
             grad_outs = call_vjp(sum_op, out_grads, stop_gradients)
         expand_op = newir_program.block().ops[-1]
-        self.assertEqual(len(grad_outs), 1)
+        self.assertEqual(len(grad_outs), 2)
         self.assertEqual(len(newir_program.block().ops), 8)
         self.assertEqual(expand_op.result(0), grad_outs[0][0])
+        self.assertEqual(grad_outs[1][0], None)
         all_op_names = [
             "pd.full",
             "pd.full",
@@ -152,14 +153,15 @@ class TestVjpPrim(unittest.TestCase):
         paddle.fluid.core._set_prim_backward_enabled(False)
         dout = newir_program.block().ops[-2].result(0)
         out_grads = [[dout]]
-        stop_gradients = [[False]]
+        stop_gradients = [[False], [True]]
         sum_op = newir_program.block().ops[-1]
         with paddle.ir.core.program_guard(newir_program):
             grad_outs = call_vjp(sum_op, out_grads, stop_gradients)
-        self.assertEqual(len(grad_outs), 1)
+        self.assertEqual(len(grad_outs), 2)
         self.assertEqual(
             grad_outs[0][0].get_defining_op().name(), "pd.sum_grad"
         )
+        self.assertEqual(grad_outs[1][0], None)
         self.assertEqual(len(newir_program.block().ops), 6)
...
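Taken together, the two test hunks pin down the updated call_vjp contract: grad_outs has one entry per op input, and inputs whose stop_gradient flag is True come back as None (here axis, the second input of sum). A small illustrative consumer of that shape (the helper name is hypothetical, not Paddle API):

def defined_grad_indices(grad_outs):
    # Collect (input_idx, tensor_idx) pairs that actually received a gradient.
    return [
        (i, j)
        for i, slots in enumerate(grad_outs)
        for j, g in enumerate(slots)
        if g is not None
    ]

# With stop_gradients = [[False], [True]] as in the tests, only the
# gradient w.r.t. x survives; the axis slot is None.
assert defined_grad_indices([["dx"], [None]]) == [(0, 0)]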