未验证 提交 d8643cb6 编写于 作者: X Xiaoxu Chen 提交者: GitHub

【PRIM】Support use operator's output metadata info in constructing static...

【PRIM】Support use operator's output metadata info  in constructing static backward composite (#50043)

* [prim] support custom target_gradients

* support infershape after append one gradop

* [prim] add simple net test

* fix test_loop segment fault bug

* [prim] fix infer shape segment fault bug when output of grad_op_desc is empty
上级 3e656414
...@@ -110,6 +110,7 @@ Tensor unsqueeze<DescTensor>(const Tensor& x, const IntArray& axis) { ...@@ -110,6 +110,7 @@ Tensor unsqueeze<DescTensor>(const Tensor& x, const IntArray& axis) {
op->SetAttr("axes", new_shape); op->SetAttr("axes", new_shape);
op->CheckAttrs(); op->CheckAttrs();
op->InferVarType(block); op->InferVarType(block);
op->InferShape(*block);
return out; return out;
} }
...@@ -209,7 +210,7 @@ Tensor sum<DescTensor>(const Tensor& x, ...@@ -209,7 +210,7 @@ Tensor sum<DescTensor>(const Tensor& x,
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()}); "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
op->CheckAttrs(); op->CheckAttrs();
op->InferVarType(block); op->InferVarType(block);
// TODO(jiabin, cxxly): This may have runtime shape skip infershape for now. op->InferShape(*block);
return out; return out;
} }
...@@ -232,7 +233,7 @@ Tensor reshape<DescTensor>(const Tensor& x, const IntArray& shape) { ...@@ -232,7 +233,7 @@ Tensor reshape<DescTensor>(const Tensor& x, const IntArray& shape) {
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()}); "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
op->CheckAttrs(); op->CheckAttrs();
op->InferVarType(block); op->InferVarType(block);
// TODO(jiabin, cxxly): This may have runtime shape skip infershape for now. op->InferShape(*block);
return out; return out;
} }
......
...@@ -232,8 +232,8 @@ void multiply_grad(const Tensor& x, ...@@ -232,8 +232,8 @@ void multiply_grad(const Tensor& x,
Tensor* y_grad) { Tensor* y_grad) {
if (x_grad) { if (x_grad) {
auto x_grad_unreduce = multiply<T>(out_grad, y); auto x_grad_unreduce = multiply<T>(out_grad, y);
if (x.dims() != y.dims()) { if (x_grad_unreduce.dims() != x.dims()) {
auto axes = get_reduce_dims(x.dims(), y.dims()); auto axes = get_reduce_dims_from_out(x_grad_unreduce.dims(), x.dims());
if (!axes.size()) { if (!axes.size()) {
set_output<T>(x_grad_unreduce, x_grad); set_output<T>(x_grad_unreduce, x_grad);
} else { } else {
...@@ -252,8 +252,8 @@ void multiply_grad(const Tensor& x, ...@@ -252,8 +252,8 @@ void multiply_grad(const Tensor& x,
} }
if (y_grad) { if (y_grad) {
auto y_grad_unreduce = multiply<T>(out_grad, x); auto y_grad_unreduce = multiply<T>(out_grad, x);
if (y.dims() != x.dims()) { if (y_grad_unreduce.dims() != y.dims()) {
auto axes = get_reduce_dims(y.dims(), x.dims()); auto axes = get_reduce_dims_from_out(y_grad_unreduce.dims(), y.dims());
if (!axes.size()) { if (!axes.size()) {
set_output<T>(y_grad_unreduce, y_grad); set_output<T>(y_grad_unreduce, y_grad);
} else { } else {
......
...@@ -318,6 +318,7 @@ class CompositeGradOpMakerBase { ...@@ -318,6 +318,7 @@ class CompositeGradOpMakerBase {
grad_var_name = framework::kEmptyVarName; grad_var_name = framework::kEmptyVarName;
if (drop_empty_grad) return nullptr; if (drop_empty_grad) return nullptr;
} }
if (original_block_->HasVar(grad_var_name)) { if (original_block_->HasVar(grad_var_name)) {
// Copy Var from original block to active block, or create a new one. // Copy Var from original block to active block, or create a new one.
CopyVarFromOrig(grad_var_name); CopyVarFromOrig(grad_var_name);
...@@ -333,6 +334,12 @@ class CompositeGradOpMakerBase { ...@@ -333,6 +334,12 @@ class CompositeGradOpMakerBase {
auto grad_var_name = framework::GradVarName(var_name); auto grad_var_name = framework::GradVarName(var_name);
(*this->grad_to_var_)[grad_var_name] = var_name; (*this->grad_to_var_)[grad_var_name] = var_name;
VLOG(8) << "Valid gradients: " << grad_var_name; VLOG(8) << "Valid gradients: " << grad_var_name;
auto target_grad = StaticCompositeContext::Instance().GetTargetGradName();
if (target_grad.find(grad_var_name) != target_grad.end()) {
grad_var_name = target_grad.at(grad_var_name);
}
if (original_block_->HasVar(grad_var_name)) { if (original_block_->HasVar(grad_var_name)) {
// Copy Var from original block to active block, or create a new one. // Copy Var from original block to active block, or create a new one.
CopyVarFromOrig(grad_var_name); CopyVarFromOrig(grad_var_name);
...@@ -421,7 +428,11 @@ class CompositeGradOpMakerBase { ...@@ -421,7 +428,11 @@ class CompositeGradOpMakerBase {
return g_name; return g_name;
}); });
std::vector<framework::VarDesc*> grad_out; std::vector<framework::VarDesc*> grad_out;
for (const auto& name : ret_val) { for (auto name : ret_val) {
auto target_grad = StaticCompositeContext::Instance().GetTargetGradName();
if (target_grad.find(name) != target_grad.end()) {
name = target_grad.at(name);
}
// TODO(jiabin): Will this cause fill zeros error? // TODO(jiabin): Will this cause fill zeros error?
if (original_block_->HasVar(name)) { if (original_block_->HasVar(name)) {
// Copy Var from original block to active block, or create a new one. // Copy Var from original block to active block, or create a new one.
......
...@@ -69,12 +69,21 @@ class StaticCompositeContext { ...@@ -69,12 +69,21 @@ class StaticCompositeContext {
enable_bwd_prim_ = enable_prim; enable_bwd_prim_ = enable_prim;
} }
void SetTargetGradName(const std::map<std::string, std::string>& m) {
target_grad_name_ = m;
}
std::map<std::string, std::string> GetTargetGradName() {
return target_grad_name_;
}
private: private:
StaticCompositeContext() StaticCompositeContext()
: current_block_desc_(nullptr), generator_(new UniqueNameGenerator()) {} : current_block_desc_(nullptr), generator_(new UniqueNameGenerator()) {}
framework::BlockDesc* current_block_desc_; framework::BlockDesc* current_block_desc_;
std::unique_ptr<UniqueNameGenerator> generator_; std::unique_ptr<UniqueNameGenerator> generator_;
std::map<std::string, std::string> target_grad_name_;
static thread_local bool enable_bwd_prim_; static thread_local bool enable_bwd_prim_;
static thread_local bool enable_fwd_prim_; static thread_local bool enable_fwd_prim_;
static StaticCompositeContext* static_composite_context_; static StaticCompositeContext* static_composite_context_;
......
...@@ -38,5 +38,11 @@ void PrimCommonUtils::SetFwdPrimEnabled(bool enable_prim) { ...@@ -38,5 +38,11 @@ void PrimCommonUtils::SetFwdPrimEnabled(bool enable_prim) {
void PrimCommonUtils::SetAllPrimEnabled(bool enable_prim) { void PrimCommonUtils::SetAllPrimEnabled(bool enable_prim) {
return StaticCompositeContext::Instance().SetAllPrimEnabled(enable_prim); return StaticCompositeContext::Instance().SetAllPrimEnabled(enable_prim);
} }
void PrimCommonUtils::SetTargetGradName(
const std::map<std::string, std::string>& m) {
StaticCompositeContext::Instance().SetTargetGradName(m);
}
} // namespace prim } // namespace prim
} // namespace paddle } // namespace paddle
...@@ -14,6 +14,9 @@ ...@@ -14,6 +14,9 @@
#pragma once #pragma once
#include <map>
#include <string>
namespace paddle { namespace paddle {
namespace prim { namespace prim {
class PrimCommonUtils { class PrimCommonUtils {
...@@ -23,6 +26,7 @@ class PrimCommonUtils { ...@@ -23,6 +26,7 @@ class PrimCommonUtils {
static bool IsFwdPrimEnabled(); static bool IsFwdPrimEnabled();
static void SetFwdPrimEnabled(bool enabled); static void SetFwdPrimEnabled(bool enabled);
static void SetAllPrimEnabled(bool enabled); static void SetAllPrimEnabled(bool enabled);
static void SetTargetGradName(const std::map<std::string, std::string>& m);
}; };
} // namespace prim } // namespace prim
} // namespace paddle } // namespace paddle
...@@ -673,6 +673,8 @@ PYBIND11_MODULE(libpaddle, m) { ...@@ -673,6 +673,8 @@ PYBIND11_MODULE(libpaddle, m) {
&paddle::prim::PrimCommonUtils::IsFwdPrimEnabled); &paddle::prim::PrimCommonUtils::IsFwdPrimEnabled);
m.def("__set_all_prim_enabled", m.def("__set_all_prim_enabled",
&paddle::prim::PrimCommonUtils::SetAllPrimEnabled); &paddle::prim::PrimCommonUtils::SetAllPrimEnabled);
m.def("_set_prim_target_grad_name",
&paddle::prim::PrimCommonUtils::SetTargetGradName);
m.def("set_num_threads", &platform::SetNumThreads); m.def("set_num_threads", &platform::SetNumThreads);
m.def("disable_signal_handler", &DisableSignalHandler); m.def("disable_signal_handler", &DisableSignalHandler);
......
...@@ -1337,6 +1337,18 @@ def _append_backward_ops_( ...@@ -1337,6 +1337,18 @@ def _append_backward_ops_(
rename_var_map = {} rename_var_map = {}
assert isinstance(rename_var_map, dict) assert isinstance(rename_var_map, dict)
if core._is_bwd_prim_enabled():
composite_block = program.clone().current_block()
# Infer shape for operators whose output haven't been created.
for op in composite_block.ops:
if not all(
tuple(
composite_block._find_var_recursive(arg)
for arg in op.output_arg_names
)
):
infershape_for_composite(composite_block, op.desc)
# add grad_op_desc by reversed ops # add grad_op_desc by reversed ops
for op in reversed(ops): for op in reversed(ops):
grad_sub_block_list = [] grad_sub_block_list = []
...@@ -1365,11 +1377,42 @@ def _append_backward_ops_( ...@@ -1365,11 +1377,42 @@ def _append_backward_ops_(
program._rollback() program._rollback()
grad_sub_block_list.append(grad_sub_block.desc) grad_sub_block_list.append(grad_sub_block.desc)
# In primitive mode, raw phi GradOp will be split into multiple small
# primitive operators, and the split rules are defined in c++ level,
# see detials: paddle/fluid/prim/api/manual/backward/composite_backward_api.h
# It means that the output's shape and dtype of previous operators which
# maybe used as the input of next operators must be known. Therefore,
# we infer shape and dtype in a sandbox block(named composite_block) for
# used in c++ level.
# For example:
# forward:
# z = multiply(x, y) //maybe broadcast in kernel
# bcckward:
# x_grad_unreduce = z_grad * y // maybe unreduce
# reduced_axes = get_reduced_axes(x_grad.shape, x.shape) // need known shape
# x_grad = reduce_sum(x_grad_unreduce)
grad_op_desc = []
op_grad_to_var = {}
if core._is_bwd_prim_enabled():
def find_op_index(block_desc, cur_op_desc):
for idx in range(block_desc.op_size()):
if cur_op_desc == block_desc.op(idx):
return idx
return -1
# Getting op's corresponding grad_op grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
grad_op_desc, op_grad_to_var = core.get_grad_op_desc( composite_block.desc.op(find_op_index(block.desc, op.desc)),
op.desc, no_grad_dict[block.idx], grad_sub_block_list no_grad_dict[composite_block.idx],
) grad_sub_block_list,
)
for desc in grad_op_desc:
infershape_for_composite(composite_block, desc)
else:
# Getting op's corresponding grad_op
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, no_grad_dict[block.idx], grad_sub_block_list
)
# record the mapping between fwd and bwd # record the mapping between fwd and bwd
if grad_op_id_to_fwd_op is not None: if grad_op_id_to_fwd_op is not None:
...@@ -1655,7 +1698,43 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map): ...@@ -1655,7 +1698,43 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
block.desc._remove_op(op_idx, op_idx + 1) block.desc._remove_op(op_idx, op_idx + 1)
def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map): def infershape_for_composite(block, grad_op_desc):
# pruning empty output
if len(grad_op_desc.output_arg_names()) == 0:
return
# append op to block
op_desc = block.desc.append_op()
op_desc.copy_from(grad_op_desc)
op_desc._set_attr(
core.op_proto_and_checker_maker.kOpRoleAttrName(),
core.op_proto_and_checker_maker.OpRole.Backward,
)
# create output var
new_vars = set()
# create new gradient variables
for grad_var_name in op_desc.output_arg_names():
if not (
block.desc.has_var_recursive(grad_var_name.encode())
or grad_var_name == core.empty_var_name()
):
block.desc.var(grad_var_name.encode())
new_vars.add(grad_var_name)
# infer shape and infer dthype
op_desc.check_attrs()
op_desc.infer_var_type(block.desc)
op_desc.infer_shape(block.desc)
for arg in op_desc.output_arg_names():
if arg in new_vars:
_infer_var_data_type_shape_(arg, block)
def _rename_grad_(
block, start_op_idx, grad_to_var, target_grad_map, skip_rename_var_list
):
var_map = copy.copy(target_grad_map) var_map = copy.copy(target_grad_map)
for op_idx in range(start_op_idx, block.desc.op_size()): for op_idx in range(start_op_idx, block.desc.op_size()):
op_desc = block.desc.op(op_idx) op_desc = block.desc.op(op_idx)
...@@ -1667,6 +1746,8 @@ def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map): ...@@ -1667,6 +1746,8 @@ def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map):
if "@GRAD" not in name: if "@GRAD" not in name:
continue continue
if block.desc.find_var(name.encode("ascii")): if block.desc.find_var(name.encode("ascii")):
if name in skip_rename_var_list:
continue
new_name = unique_name.generate(name) new_name = unique_name.generate(name)
op_desc._rename_output(name, new_name) op_desc._rename_output(name, new_name)
var_map[name] = new_name var_map[name] = new_name
...@@ -1993,7 +2074,7 @@ def append_backward( ...@@ -1993,7 +2074,7 @@ def append_backward(
# Because append_backward may be called multiple times, # Because append_backward may be called multiple times,
# we need rename the internal gradient variables so that they have # we need rename the internal gradient variables so that they have
# different names. # different names.
_rename_grad_(target_grad_block, fwd_op_num, grad_to_var, {}) _rename_grad_(target_grad_block, fwd_op_num, grad_to_var, {}, [])
_append_backward_vars_( _append_backward_vars_(
target_grad_block, fwd_op_num, grad_to_var, grad_info_map target_grad_block, fwd_op_num, grad_to_var, grad_info_map
...@@ -2297,33 +2378,24 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None): ...@@ -2297,33 +2378,24 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
target_grad_map = {} target_grad_map = {}
rename_var_map = {} rename_var_map = {}
skip_rename_var_list = []
for i, grad in enumerate(target_gradients): for i, grad in enumerate(target_gradients):
target = targets[i] target = targets[i]
grad_name = _append_grad_suffix_(target.name) grad_name = _append_grad_suffix_(target.name)
if grad is None: if grad is None:
target_shape = target.name + '_shape'
block.desc.append_op().copy_from(
_create_op_desc_(
"shape",
{'Input': [target.name]},
{"Out": [target_shape]},
{},
)
)
input_grad_names_set.add(target_shape)
op_desc = _create_op_desc_( op_desc = _create_op_desc_(
"fill_constant", "fill_any_like",
{"ShapeTensor": [target_shape]}, {"X": [target.name]},
{"Out": [grad_name]}, {"Out": [grad_name]},
{ {
"shape": target.shape,
"value": 1.0, "value": 1.0,
"dtype": target.dtype, "dtype": target.dtype,
}, },
) )
block.desc.append_op().copy_from(op_desc) block.desc.append_op().copy_from(op_desc)
block.program._sync_with_cpp()
input_grad_names_set.add(grad_name) input_grad_names_set.add(grad_name)
skip_rename_var_list.append(grad_name)
else: else:
if target.block.idx != block_idx or target.block.program != prog: if target.block.idx != block_idx or target.block.program != prog:
raise ValueError("all targets must be in the same block") raise ValueError("all targets must be in the same block")
...@@ -2336,6 +2408,9 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None): ...@@ -2336,6 +2408,9 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
input_grad_names_set.add(grad.name) input_grad_names_set.add(grad.name)
rename_var_map[grad_name] = grad.name rename_var_map[grad_name] = grad.name
if core._is_bwd_prim_enabled():
core._set_prim_target_grad_name(target_grad_map)
# For double backward, input_grad_names is used for filter # For double backward, input_grad_names is used for filter
# some non-used gradients op. rename_var_map is used to # some non-used gradients op. rename_var_map is used to
# associate target_grad var name with first grad_op input name. # associate target_grad var name with first grad_op input name.
...@@ -2378,7 +2453,9 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None): ...@@ -2378,7 +2453,9 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
# Because calc_gradient may be called multiple times, # Because calc_gradient may be called multiple times,
# we need rename the internal gradient variables so that they have # we need rename the internal gradient variables so that they have
# different names. # different names.
_rename_grad_(block, fwd_op_num, grad_to_var, target_grad_map) _rename_grad_(
block, fwd_op_num, grad_to_var, target_grad_map, skip_rename_var_list
)
_append_backward_vars_(block, fwd_op_num, grad_to_var, grad_info_map) _append_backward_vars_(block, fwd_op_num, grad_to_var, grad_info_map)
prog._sync_with_cpp() prog._sync_with_cpp()
......
...@@ -313,6 +313,7 @@ try: ...@@ -313,6 +313,7 @@ try:
from .libpaddle import __set_fwd_prim_enabled from .libpaddle import __set_fwd_prim_enabled
from .libpaddle import _is_fwd_prim_enabled from .libpaddle import _is_fwd_prim_enabled
from .libpaddle import __set_all_prim_enabled from .libpaddle import __set_all_prim_enabled
from .libpaddle import _set_prim_target_grad_name
# custom devivce # custom devivce
from .libpaddle import _get_current_custom_device_stream from .libpaddle import _get_current_custom_device_stream
......
...@@ -9,3 +9,4 @@ foreach(TEST_OP ${TEST_OPS}) ...@@ -9,3 +9,4 @@ foreach(TEST_OP ${TEST_OPS})
endforeach() endforeach()
add_subdirectory(prim) add_subdirectory(prim)
add_subdirectory(model)
file(
GLOB TEST_OPS
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import parameterized as param
import paddle
from paddle.fluid import core, framework
@param.parameterized_class(
('name', 'primals', 'stop_gradients', 'cotangents', 'dtype'),
(
(
'test_normal_case',
(np.random.rand(2, 3, 4), np.random.rand(2, 3, 4)),
(False, False),
(np.random.rand(2, 3, 4),),
np.float32,
),
(
'test_broadcast_diff_rank',
(np.random.rand(2, 3, 1, 4), np.random.rand(3, 3, 4)),
(False, False),
(np.random.rand(2, 3, 3, 4),),
np.float32,
),
(
'test_broadcast_same_rank',
(np.random.rand(2, 3, 1, 4), np.random.rand(2, 1, 3, 4)),
(False, False),
(np.random.rand(2, 3, 3, 4),),
np.float32,
),
(
'test_stop_gradient',
(np.random.rand(2, 3, 1, 4), np.random.rand(2, 1, 3, 4)),
(False, True),
(np.random.rand(2, 3, 3, 4),),
np.float32,
),
),
)
class TestMultiplyGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.primals = tuple(primal.astype(cls.dtype) for primal in cls.primals)
cls.cotangents = tuple(co.astype(cls.dtype) for co in cls.cotangents)
def setUp(self):
paddle.enable_static()
def tearDown(self):
paddle.disable_static()
def as_tuple(self, x):
return (x,) if isinstance(x, framework.Variable) else x
def net(self):
primals, cotangents = self.primals, self.cotangents
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
primals = tuple(
paddle.static.data(f'primal{i}', primal.shape, primal.dtype)
for i, primal in enumerate(primals)
)
for primal, flag in zip(primals, self.stop_gradients):
primal.stop_gradient = flag
cotangents = tuple(
paddle.static.data(f'cotangent{i}', co.shape, co.dtype)
for i, co in enumerate(cotangents)
)
out = self.as_tuple(paddle.tanh(paddle.multiply(*primals)))
grads = paddle.static.gradients(out, primals)
exe = paddle.static.Executor()
exe.run(sp)
return exe.run(
program=mp,
feed={
**{
f'primal{i}': primal
for i, primal in enumerate(self.primals)
},
**{f'cotangent{i}': co for i, co in enumerate(self.cotangents)},
},
fetch_list=[g for g in grads if g is not None],
)
def test_comp(self):
core._set_prim_backward_enabled(True)
actual = self.net()
core._set_prim_backward_enabled(False)
desired = self.net()
self.assertEqual(len(actual), len(desired))
for i, j in zip(actual, desired):
np.testing.assert_allclose(
i,
j,
rtol=1e-6,
atol=0,
)
if __name__ == '__main__':
unittest.main()
...@@ -56,7 +56,7 @@ from paddle.fluid import core, framework ...@@ -56,7 +56,7 @@ from paddle.fluid import core, framework
'test_reduce_axe_empty', 'test_reduce_axe_empty',
(np.random.rand(2, 3, 3, 4), np.random.rand(2, 1, 3, 4)), (np.random.rand(2, 3, 3, 4), np.random.rand(2, 1, 3, 4)),
(False, False), (False, False),
(np.random.rand(2, 1, 3, 1),), (np.random.rand(2, 3, 3, 4),),
np.float32, np.float32,
), ),
), ),
...@@ -91,7 +91,7 @@ class TestMultiplyGradComp(unittest.TestCase): ...@@ -91,7 +91,7 @@ class TestMultiplyGradComp(unittest.TestCase):
for i, co in enumerate(cotangents) for i, co in enumerate(cotangents)
) )
out = self.as_tuple(paddle.multiply(*primals)) out = self.as_tuple(paddle.multiply(*primals))
grads = paddle.static.gradients(out, primals) grads = paddle.static.gradients(out, primals, cotangents)
exe = paddle.static.Executor() exe = paddle.static.Executor()
exe.run(sp) exe.run(sp)
return exe.run( return exe.run(
......
...@@ -75,7 +75,6 @@ class TestGetGradOpDescPrimEnabled(unittest.TestCase): ...@@ -75,7 +75,6 @@ class TestGetGradOpDescPrimEnabled(unittest.TestCase):
self.fwd, self.no_grad_var, self.grad_sub_block self.fwd, self.no_grad_var, self.grad_sub_block
)[0] )[0]
) )
print(actual)
self.assertEquals(actual, self.desired_ops) self.assertEquals(actual, self.desired_ops)
core._set_prim_backward_enabled(False) core._set_prim_backward_enabled(False)
......
...@@ -105,7 +105,7 @@ class TestRunProgram(unittest.TestCase): ...@@ -105,7 +105,7 @@ class TestRunProgram(unittest.TestCase):
) )
backward_program = _add_build_strategy_for( backward_program = _add_build_strategy_for(
program, program,
main_program.desc.block(0).op_size() + 2, main_program.desc.block(0).op_size() + 1,
program.desc.block(0).op_size(), program.desc.block(0).op_size(),
) )
......
...@@ -131,7 +131,7 @@ class RunProgramOpTest(unittest.TestCase): ...@@ -131,7 +131,7 @@ class RunProgramOpTest(unittest.TestCase):
forward_program = _add_build_strategy_for(program, 0, forward_op_num) forward_program = _add_build_strategy_for(program, 0, forward_op_num)
backward_program = _add_build_strategy_for( backward_program = _add_build_strategy_for(
program, program,
forward_op_num + 2 * output_num, forward_op_num + output_num,
program.desc.block(0).op_size(), program.desc.block(0).op_size(),
) )
return forward_program.desc, backward_program.desc return forward_program.desc, backward_program.desc
......
...@@ -576,9 +576,7 @@ class PartialProgramLayer: ...@@ -576,9 +576,7 @@ class PartialProgramLayer:
core.check_and_set_prim_all_enabled() core.check_and_set_prim_all_enabled()
backward.gradients(targets=targets, inputs=[]) backward.gradients(targets=targets, inputs=[])
start_idx = len(main_program.block(0).ops) + 2 * len( start_idx = len(main_program.block(0).ops) + len(self._outputs.tolist())
self._outputs.tolist()
)
self.prepare_gradient_aggregation(start_idx, main_program, program) self.prepare_gradient_aggregation(start_idx, main_program, program)
...@@ -753,7 +751,7 @@ class PartialProgramLayer: ...@@ -753,7 +751,7 @@ class PartialProgramLayer:
): ):
# NOTE(dev): We apply build_strategy for backward firstly to # NOTE(dev): We apply build_strategy for backward firstly to
# avoid skipping more gc variables. # avoid skipping more gc variables.
backward_start_op_index = forward_end_op_index + 2 * len( backward_start_op_index = forward_end_op_index + len(
self._outputs.var_ids self._outputs.var_ids
) )
backward_end_op_index = whole_program.desc.block(0).op_size() backward_end_op_index = whole_program.desc.block(0).op_size()
......
...@@ -1512,12 +1512,11 @@ def _out_grad_names(program_desc, fwd_end_op_index, out_size): ...@@ -1512,12 +1512,11 @@ def _out_grad_names(program_desc, fwd_end_op_index, out_size):
""" """
names = [] names = []
for i in range( for i in range(
fwd_end_op_index + 1, fwd_end_op_index,
min(fwd_end_op_index + 2 * out_size, program_desc.block(0).op_size()), min(fwd_end_op_index + out_size, program_desc.block(0).op_size()),
2,
): ):
op = program_desc.block(0).op(i) op = program_desc.block(0).op(i)
if op.type() == 'fill_constant': if op.type() == 'fill_any_like':
var_name = op.output('Out')[0] var_name = op.output('Out')[0]
names.append(var_name) names.append(var_name)
return names return names
...@@ -373,7 +373,7 @@ class _ProgramHolder: ...@@ -373,7 +373,7 @@ class _ProgramHolder:
@switch_to_static_graph @switch_to_static_graph
def _create_backward_train_program(self): def _create_backward_train_program(self):
whole_program = _build_program_by_desc(self._train_program_desc) whole_program = _build_program_by_desc(self._train_program_desc)
start_op_index = self._infer_program_desc.block(0).op_size() + 2 * len( start_op_index = self._infer_program_desc.block(0).op_size() + len(
self._output_descs self._output_descs
) )
end_op_index = whole_program.desc.block(0).op_size() end_op_index = whole_program.desc.block(0).op_size()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册