未验证 提交 bbca66f2 编写于 作者: J Jiabin Yang 提交者: GitHub

【Prim】Fix slice error and eager comp (#51086)

* fix attrs copy error

* fix bert by fix slice error

* fix op test
上级 41e5667b
......@@ -1840,7 +1840,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
if is_composite_grad_api and next_grad_node_creation_str != '':
next_grad_node_creation_str = f"""
if (!paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {{
if (!paddle::prim::PrimCommonUtils::IsEagerPrimEnabled()) {{
{next_grad_node_creation_str}
}}
"""
......@@ -2260,7 +2260,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
# TODO(Ruting):using composite only when we don't have backward kernel in the future.
elif is_composite_grad_api:
grad_function_call_str = f"""
if (paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {{
if (paddle::prim::PrimCommonUtils::IsEagerPrimEnabled()) {{
{indent}{composite_grad_api_namespace}{composite_grad_api_name}{composite_template_name}({composite_grad_api_args_str});
VLOG(4) << "Composite api {composite_grad_api_name} is called ";
}}else{{
......
......@@ -423,19 +423,25 @@ class SliceCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
auto dx_ptr = this->GetOutputPtr(&input_grad);
std::string dx_name = this->GetOutputName(input_grad);
auto axes = this->Attr<std::vector<int64_t>>("axes");
auto starts = this->Attr<std::vector<int64_t>>("starts");
auto ends = this->Attr<std::vector<int64_t>>("ends");
auto infer_flags = this->Attr<std::vector<int64_t>>("infer_flags");
auto decrease_axis = this->Attr<std::vector<int64_t>>("decrease_axis");
auto axes = this->Attr<std::vector<int>>("axes");
auto starts = this->Attr<std::vector<int>>("starts");
auto ends = this->Attr<std::vector<int>>("ends");
auto infer_flags = this->Attr<std::vector<int>>("infer_flags");
auto decrease_axis = this->Attr<std::vector<int>>("decrease_axis");
VLOG(6) << "Runing slice_grad composite func";
std::vector<int64_t> new_axes =
std::vector<int64_t>(axes.begin(), axes.end());
std::vector<int64_t> new_infer_flags =
std::vector<int64_t>(infer_flags.begin(), infer_flags.end());
std::vector<int64_t> new_decrease_axis =
std::vector<int64_t>(decrease_axis.begin(), decrease_axis.end());
prim::slice_grad<prim::DescTensor>(input,
out_grad,
axes,
new_axes,
paddle::experimental::IntArray(starts),
paddle::experimental::IntArray(ends),
infer_flags,
decrease_axis,
new_infer_flags,
new_decrease_axis,
dx_ptr);
this->RecoverOutputName(input_grad, dx_name);
}
......@@ -478,6 +484,7 @@ REGISTER_OPERATOR(slice,
ops::SliceOpMaker,
ops::SliceOpGradMaker<paddle::framework::OpDesc>,
ops::SliceOpGradMaker<paddle::imperative::OpBase>,
ops::SliceCompositeGradOpMaker,
ops::SliceOpVarTypeInference);
REGISTER_OPERATOR(slice_grad,
ops::SliceOpGrad,
......
......@@ -704,6 +704,7 @@ void slice_grad(const Tensor& input,
if (input_grad) {
size_t rank = input.dims().size();
auto out_dims = out_grad.dims();
std::vector<int64_t> origin_out_shape;
auto in_dims = input.dims();
auto decrease_size = decrease_axis.size();
......@@ -712,7 +713,7 @@ void slice_grad(const Tensor& input,
// all dims decrease
out_dims = phi::make_ddim(std::vector<int>(decrease_size, 1));
} else {
std::vector<int> origin_out_shape(out_dims.size() + decrease_size, -1);
origin_out_shape.resize(out_dims.size() + decrease_size, -1);
for (size_t i = 0; i < decrease_size; ++i) {
origin_out_shape[decrease_axis[i]] = 1;
}
......@@ -734,7 +735,6 @@ void slice_grad(const Tensor& input,
offsets[i] = 0;
extents[i] = out_dims[i];
}
for (size_t i = 0; i < axes.size(); ++i) {
int axis = axes[i];
int64_t start = starts[i] < 0 ? (starts[i] + in_dims[axis]) : starts[i];
......@@ -747,9 +747,15 @@ void slice_grad(const Tensor& input,
paddings.push_back(offsets[i]);
paddings.push_back((in_dims[i] - out_dims[i]) - offsets[i]);
}
auto out_tmp = pad<T>(out_grad, paddings, 0.0);
set_output<T>(out_tmp, input_grad);
if (decrease_size > 0 &&
(decrease_size != static_cast<size_t>(in_dims.size()))) {
auto out_tmp =
pad<T>(reshape<T>(out_grad, origin_out_shape), paddings, 0.0);
set_output<T>(out_tmp, input_grad);
} else {
auto out_tmp = pad<T>(out_grad, paddings, 0.0);
set_output<T>(out_tmp, input_grad);
}
}
}
......
......@@ -20,5 +20,6 @@ StaticCompositeContext* StaticCompositeContext::static_composite_context_ =
new StaticCompositeContext();
thread_local bool StaticCompositeContext::enable_bwd_prim_ = false;
thread_local bool StaticCompositeContext::enable_fwd_prim_ = false;
thread_local bool StaticCompositeContext::enable_eager_prim_ = false;
} // namespace prim
} // namespace paddle
......@@ -65,6 +65,12 @@ class StaticCompositeContext {
bool IsFwdPrimEnabled() { return enable_fwd_prim_; }
void SetEagerPrimEnabled(bool enable_prim) {
enable_eager_prim_ = enable_prim;
}
bool IsEagerPrimEnabled() { return enable_eager_prim_; }
void SetAllPrimEnabled(bool enable_prim) {
enable_fwd_prim_ = enable_prim;
enable_bwd_prim_ = enable_prim;
......@@ -102,6 +108,7 @@ class StaticCompositeContext {
std::map<std::string, std::string> target_grad_name_;
static thread_local bool enable_bwd_prim_;
static thread_local bool enable_fwd_prim_;
static thread_local bool enable_eager_prim_;
static StaticCompositeContext* static_composite_context_;
DISABLE_COPY_AND_ASSIGN(StaticCompositeContext);
};
......
......@@ -27,6 +27,14 @@ void PrimCommonUtils::SetBwdPrimEnabled(bool enable_prim) {
StaticCompositeContext::Instance().SetBwdPrimEnabled(enable_prim);
}
bool PrimCommonUtils::IsEagerPrimEnabled() {
return StaticCompositeContext::Instance().IsEagerPrimEnabled();
}
void PrimCommonUtils::SetEagerPrimEnabled(bool enable_prim) {
StaticCompositeContext::Instance().SetEagerPrimEnabled(enable_prim);
}
bool PrimCommonUtils::IsFwdPrimEnabled() {
return StaticCompositeContext::Instance().IsFwdPrimEnabled();
}
......
......@@ -23,6 +23,8 @@ class PrimCommonUtils {
public:
static bool IsBwdPrimEnabled();
static void SetBwdPrimEnabled(bool enabled);
static bool IsEagerPrimEnabled();
static void SetEagerPrimEnabled(bool enabled);
static bool IsFwdPrimEnabled();
static void SetFwdPrimEnabled(bool enabled);
static void SetAllPrimEnabled(bool enabled);
......
......@@ -681,6 +681,10 @@ PYBIND11_MODULE(libpaddle, m) {
&paddle::prim::PrimCommonUtils::IsFwdPrimEnabled);
m.def("__set_all_prim_enabled",
&paddle::prim::PrimCommonUtils::SetAllPrimEnabled);
m.def("_is_eager_prim_enabled",
&paddle::prim::PrimCommonUtils::IsEagerPrimEnabled);
m.def("__set_eager_prim_enabled",
&paddle::prim::PrimCommonUtils::SetEagerPrimEnabled);
m.def("_set_prim_target_grad_name",
&paddle::prim::PrimCommonUtils::SetTargetGradName);
m.def("set_num_threads", &platform::SetNumThreads);
......
......@@ -316,6 +316,8 @@ try:
from .libpaddle import __set_fwd_prim_enabled
from .libpaddle import _is_fwd_prim_enabled
from .libpaddle import __set_all_prim_enabled
from .libpaddle import _is_eager_prim_enabled
from .libpaddle import __set_eager_prim_enabled
from .libpaddle import _set_prim_target_grad_name
# custom devivce
......@@ -475,26 +477,36 @@ def _set_prim_forward_blacklist(ops=None):
def _set_prim_backward_enabled(value):
__set_bwd_prim_enabled(bool(value))
print("backward prim enabled: ", bool(_is_bwd_prim_enabled()))
if os.getenv("FLAGS_prim_log") is "1":
print("backward prim enabled: ", bool(_is_bwd_prim_enabled()))
def _set_prim_forward_enabled(value):
__set_fwd_prim_enabled(bool(value))
print("forward prim enabled: ", bool(_is_fwd_prim_enabled()))
if os.getenv("FLAGS_prim_log") is "1":
print("forward prim enabled: ", bool(_is_fwd_prim_enabled()))
def set_prim_eager_enabled(value):
__set_eager_prim_enabled(bool(value))
if os.getenv("FLAGS_prim_log") is "1":
print("eager prim enabled: ", bool(_is_eager_prim_enabled()))
def _set_prim_all_enabled(value):
__set_all_prim_enabled(bool(value))
print(
"all prim enabled: ",
bool(_is_fwd_prim_enabled() and _is_bwd_prim_enabled()),
)
if os.getenv("FLAGS_prim_log") is "1":
print(
"all prim enabled: ",
bool(_is_fwd_prim_enabled() and _is_bwd_prim_enabled()),
)
def __sync_prim_backward_status():
flag_value = os.getenv("FLAGS_prim_backward")
if flag_value is None:
print("backward prim enabled: ", bool(_is_bwd_prim_enabled()))
if os.getenv("FLAGS_prim_log") is "1":
print("backward prim enabled: ", bool(_is_bwd_prim_enabled()))
else:
__sync_stat_with_flag("FLAGS_prim_backward")
......@@ -502,7 +514,8 @@ def __sync_prim_backward_status():
def __sync_prim_forward_status():
flag_value = os.getenv("FLAGS_prim_forward")
if flag_value is None:
print("forward prim enabled: ", bool(_is_fwd_prim_enabled()))
if os.getenv("FLAGS_prim_log") is 1:
print("forward prim enabled: ", bool(_is_fwd_prim_enabled()))
else:
__sync_stat_with_flag("FLAGS_prim_forward")
......
......@@ -207,7 +207,7 @@ class BertPooler(nn.Layer):
class BertModel(nn.Layer):
def __init__(self, config: BertConfig):
def __init__(self, config: BertConfig, to_static):
super(BertModel, self).__init__()
self.config = config
self.pad_token_id = config.pad_token_id
......@@ -247,6 +247,8 @@ class BertModel(nn.Layer):
self.encoder = nn.TransformerEncoder(
encoder_layer, config.num_hidden_layers
)
if to_static:
self.encoder = paddle.jit.to_static(self.encoder)
self.pooler = BertPooler(config)
# self.apply(self.init_weights)
......@@ -364,10 +366,10 @@ class BertModel(nn.Layer):
class Bert(nn.Layer):
def __init__(self):
def __init__(self, to_static):
super(Bert, self).__init__()
config = BertConfig()
self.bert = BertModel(config)
self.bert = BertModel(config, to_static)
self.cls = BertPretrainingHeads(
config,
embedding_weights=self.bert.embeddings.word_embeddings.weight,
......
......@@ -58,7 +58,7 @@ def train(to_static, enable_prim, enable_cinn):
worker_init=None,
)
bert = Bert()
bert = Bert(to_static)
criterion = BertPretrainingCriterion()
if to_static:
# input_sepc = [
......@@ -72,9 +72,6 @@ def train(to_static, enable_prim, enable_cinn):
build_strategy = paddle.static.BuildStrategy()
if enable_cinn:
build_strategy.build_cinn_pass = True
bert = paddle.jit.to_static(
bert, input_sepc, build_strategy=build_strategy
)
optimizer = fluid.optimizer.Adam(parameter_list=bert.parameters())
......
......@@ -58,6 +58,9 @@ class TestPrimFlags(unittest.TestCase):
core.check_and_set_prim_all_enabled()
self.assertFalse(core._is_fwd_prim_enabled())
core.set_prim_eager_enabled(True)
self.assertTrue(core._is_eager_prim_enabled())
with self.assertRaises(TypeError):
core._test_use_sync("aaaa")
......
......@@ -20,7 +20,7 @@ import parameterized as param
import paddle
from paddle.fluid import core
core._set_prim_backward_enabled(True)
core.set_prim_eager_enabled(True)
@param.parameterized_class(
......@@ -61,7 +61,7 @@ class TestAddGradComp(unittest.TestCase):
def test_add_grad_comp(self):
def actual(primal0, primal1):
core._set_prim_backward_enabled(True)
core.set_prim_eager_enabled(True)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
......@@ -72,7 +72,7 @@ class TestAddGradComp(unittest.TestCase):
return res[0].numpy(), res[1].numpy()
def desired(primal0, primal1):
core._set_prim_backward_enabled(False)
core.set_prim_eager_enabled(False)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
......@@ -98,7 +98,7 @@ class TestAddGradComp(unittest.TestCase):
rtol=1e-6,
atol=0,
)
core._set_prim_backward_enabled(False)
core.set_prim_eager_enabled(False)
if __name__ == '__main__':
......
......@@ -52,7 +52,7 @@ class TestCastGradComp(unittest.TestCase):
cls.cotangent = cls.cotangent.astype(cls.src_dtype)
def test_cast_grad_comp(self):
core._set_prim_backward_enabled(True)
core.set_prim_eager_enabled(True)
def actual(primal, cotangent):
x = paddle.to_tensor(primal)
......@@ -78,7 +78,7 @@ class TestCastGradComp(unittest.TestCase):
rtol=1e-6,
atol=0,
)
core._set_prim_backward_enabled(False)
core.set_prim_eager_enabled(False)
if __name__ == '__main__':
......
......@@ -20,7 +20,7 @@ import parameterized as param
import paddle
from paddle.fluid import core
core._set_prim_backward_enabled(True)
core.set_prim_eager_enabled(True)
@param.parameterized_class(
......@@ -61,7 +61,7 @@ class TestDivGradComp(unittest.TestCase):
def test_div_grad_comp(self):
def actual(primal0, primal1):
core._set_prim_backward_enabled(True)
core.set_prim_eager_enabled(True)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
......@@ -72,7 +72,7 @@ class TestDivGradComp(unittest.TestCase):
return res[0].numpy(), res[1].numpy()
def desired(primal0, primal1):
core._set_prim_backward_enabled(False)
core.set_prim_eager_enabled(False)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
......@@ -98,7 +98,7 @@ class TestDivGradComp(unittest.TestCase):
rtol=1e-6,
atol=0,
)
core._set_prim_backward_enabled(False)
core.set_prim_eager_enabled(False)
if __name__ == '__main__':
......
......@@ -32,14 +32,14 @@ from paddle.fluid import core
class TestExpGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
core._set_prim_backward_enabled(True)
core.set_prim_eager_enabled(True)
cls.primal = cls.primal.astype(cls.dtype)
if cls.cotangent is not None:
cls.cotangent = cls.cotangent.astype(cls.dtype)
@classmethod
def tearDownClass(cls):
core._set_prim_backward_enabled(False)
core.set_prim_eager_enabled(False)
def test_exp_grad_comp(self):
def actual(primal, cotangent):
......
......@@ -62,7 +62,7 @@ class TestExpandGradComp(unittest.TestCase):
@classmethod
def tearDownClass(cls):
core._set_prim_backward_enabled(False)
core.set_prim_eager_enabled(False)
def test_comp(self):
def func(primal, cotangent, shape):
......@@ -74,11 +74,11 @@ class TestExpandGradComp(unittest.TestCase):
]
def actual(primal, cotangent, shape):
core._set_prim_backward_enabled(True)
core.set_prim_eager_enabled(True)
return func(primal, cotangent, shape)
def desired(primal, cotangent, shape):
core._set_prim_backward_enabled(False)
core.set_prim_eager_enabled(False)
return func(primal, cotangent, shape)
np.testing.assert_allclose(
......
......@@ -75,11 +75,11 @@ class TestGatherGradComp(unittest.TestCase):
@classmethod
def tearDownClass(cls):
core._set_prim_backward_enabled(False)
core.set_prim_eager_enabled(False)
def test_exp_grad_comp(self):
def actual(primal0, index, axis):
core._set_prim_backward_enabled(True)
core.set_prim_eager_enabled(True)
paddle.disable_static()
x = paddle.to_tensor(
primal0, dtype=primal0.dtype, stop_gradient=False
......@@ -92,7 +92,7 @@ class TestGatherGradComp(unittest.TestCase):
return res[0].numpy()
def desired(primal0, index, axis):
core._set_prim_backward_enabled(False)
core.set_prim_eager_enabled(False)
paddle.disable_static()
x = paddle.to_tensor(
primal0, dtype=primal0.dtype, stop_gradient=False
......
......@@ -20,7 +20,7 @@ import parameterized as param
import paddle
from paddle.fluid import core
core._set_prim_backward_enabled(True)
core.set_prim_eager_enabled(True)
# vector * vector out.shape = (1)
# matrix * vector out.shape = (2)
......@@ -267,7 +267,7 @@ class TestMatmulDoubleGradComp(unittest.TestCase):
def test_matmul_grad_comp(self):
def actual(primal0, primal1, trans_0, trans_1, dtype_):
core._set_prim_backward_enabled(True)
core.set_prim_eager_enabled(True)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype=dtype_, stop_gradient=False)
y = paddle.to_tensor(primal1, dtype=dtype_, stop_gradient=False)
......@@ -287,7 +287,7 @@ class TestMatmulDoubleGradComp(unittest.TestCase):
)
def desired(primal0, primal1, trans_0, trans_1, dtype_):
core._set_prim_backward_enabled(False)
core.set_prim_eager_enabled(False)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype=dtype_, stop_gradient=False)
y = paddle.to_tensor(primal1, dtype=dtype_, stop_gradient=False)
......@@ -428,7 +428,7 @@ class TestMatmulTribleGradComp(unittest.TestCase):
def test_matmul_grad_comp(self):
def actual(primal0, primal1, trans_0, trans_1, dtype_):
core._set_prim_backward_enabled(True)
core.set_prim_eager_enabled(True)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype=dtype_, stop_gradient=False)
y = paddle.to_tensor(primal1, dtype=dtype_, stop_gradient=False)
......@@ -465,7 +465,7 @@ class TestMatmulTribleGradComp(unittest.TestCase):
)
def desired(primal0, primal1, trans_0, trans_1, dtype_):
core._set_prim_backward_enabled(False)
core.set_prim_eager_enabled(False)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype=dtype_, stop_gradient=False)
y = paddle.to_tensor(primal1, dtype=dtype_, stop_gradient=False)
......@@ -549,7 +549,7 @@ class TestMatmulTribleGradComp(unittest.TestCase):
atol=TOLERANCE[d_type]['atol'],
)
core._set_prim_backward_enabled(False)
core.set_prim_eager_enabled(False)
if __name__ == '__main__':
......
......@@ -81,10 +81,10 @@ class TestMultiplyGradComp(unittest.TestCase):
return [g for g in grads if g is not None]
def test_comp(self):
core._set_prim_backward_enabled(True)
core.set_prim_eager_enabled(True)
actual = self.vjp()
core._set_prim_backward_enabled(False)
core.set_prim_eager_enabled(False)
desired = self.vjp()
for i, j in zip(actual, desired):
......
......@@ -20,7 +20,7 @@ import parameterized as param
import paddle
from paddle.fluid import core
core._set_prim_backward_enabled(True)
core.set_prim_eager_enabled(True)
@param.parameterized_class(
......@@ -42,7 +42,7 @@ class TestReshapeGradComp(unittest.TestCase):
def test_reshape_grad_comp(self):
def actual(primal0, shape):
core._set_prim_backward_enabled(True)
core.set_prim_eager_enabled(True)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
x.stop_gradient = False
......@@ -51,7 +51,7 @@ class TestReshapeGradComp(unittest.TestCase):
return res[0].numpy()
def desired(primal0, shape):
core._set_prim_backward_enabled(False)
core.set_prim_eager_enabled(False)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
x.stop_gradient = False
......@@ -69,7 +69,7 @@ class TestReshapeGradComp(unittest.TestCase):
rtol=1e-6,
atol=0,
)
core._set_prim_backward_enabled(False)
core.set_prim_eager_enabled(False)
if __name__ == '__main__':
......
......@@ -22,7 +22,7 @@ import parameterized as param
import paddle
from paddle.fluid import core
core._set_prim_backward_enabled(True)
core.set_prim_eager_enabled(True)
@param.parameterized_class(
......@@ -57,7 +57,7 @@ class TestSqrtGradComp(unittest.TestCase):
rtol=1e-6,
atol=0,
)
core._set_prim_backward_enabled(False)
core.set_prim_eager_enabled(False)
if __name__ == '__main__':
......
......@@ -20,7 +20,7 @@ import parameterized as param
import paddle
from paddle.fluid import core
core._set_prim_backward_enabled(True)
core.set_prim_eager_enabled(True)
@param.parameterized_class(
......@@ -61,7 +61,7 @@ class TestSubGradComp(unittest.TestCase):
def test_sub_grad_comp(self):
def actual(primal0, primal1):
core._set_prim_backward_enabled(True)
core.set_prim_eager_enabled(True)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
......@@ -72,7 +72,7 @@ class TestSubGradComp(unittest.TestCase):
return res[0].numpy(), res[1].numpy()
def desired(primal0, primal1):
core._set_prim_backward_enabled(False)
core.set_prim_eager_enabled(False)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
......@@ -98,7 +98,7 @@ class TestSubGradComp(unittest.TestCase):
rtol=1e-6,
atol=0,
)
core._set_prim_backward_enabled(False)
core.set_prim_eager_enabled(False)
if __name__ == '__main__':
......
......@@ -21,7 +21,7 @@ from paddle.fluid import core
def actual(primal, cotangent, axis, keep_dim):
core._set_prim_backward_enabled(False)
core.set_prim_eager_enabled(False)
x = paddle.to_tensor(primal, dtype='float32', stop_gradient=False)
v = paddle.to_tensor(cotangent, dtype='float32', stop_gradient=False)
y = paddle.sum(x, axis=axis, keepdim=keep_dim)
......@@ -30,7 +30,7 @@ def actual(primal, cotangent, axis, keep_dim):
def desired(primal, cotangent, axis, keep_dim):
core._set_prim_backward_enabled(True)
core.set_prim_eager_enabled(True)
x = paddle.to_tensor(primal, dtype='float32', stop_gradient=False)
v = paddle.to_tensor(cotangent, dtype='float32', stop_gradient=False)
y = paddle.sum(x, axis=axis, keepdim=keep_dim)
......
......@@ -20,7 +20,7 @@ import parameterized as param
import paddle
from paddle.fluid import core
core._set_prim_backward_enabled(True)
core.set_prim_eager_enabled(True)
@param.parameterized_class(
......@@ -68,7 +68,7 @@ class TestTanhGradComp(unittest.TestCase):
rtol=1e-6,
atol=0,
)
core._set_prim_backward_enabled(False)
core.set_prim_eager_enabled(False)
if __name__ == '__main__':
......
......@@ -20,7 +20,7 @@ import parameterized as param
import paddle
from paddle.fluid import core
core._set_prim_backward_enabled(True)
core.set_prim_eager_enabled(True)
@param.parameterized_class(
......@@ -72,7 +72,7 @@ class TestTransposeGradComp(unittest.TestCase):
def test_transpose_grad_comp(self):
def actual(primal0, shape):
core._set_prim_backward_enabled(True)
core.set_prim_eager_enabled(True)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
x.stop_gradient = False
......@@ -81,7 +81,7 @@ class TestTransposeGradComp(unittest.TestCase):
return res[0].numpy()
def desired(primal0, shape):
core._set_prim_backward_enabled(False)
core.set_prim_eager_enabled(False)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
x.stop_gradient = False
......@@ -99,7 +99,7 @@ class TestTransposeGradComp(unittest.TestCase):
rtol=1e-6,
atol=0,
)
core._set_prim_backward_enabled(False)
core.set_prim_eager_enabled(False)
if __name__ == '__main__':
......
......@@ -906,7 +906,7 @@ class PrimGradChecker(PrimForwardChecker):
paddle.device.set_device("gpu:0")
atol = self.rev_comp_atol
rtol = self.rev_comp_rtol
core._set_prim_backward_enabled(self.enable_rev_comp)
core.set_prim_eager_enabled(self.enable_rev_comp)
actual_ret = self.get_eager_desire()
# check static forward
if len(actual_ret) != len(self.eager_desire):
......@@ -941,6 +941,7 @@ class PrimGradChecker(PrimForwardChecker):
)
)
raise RuntimeError(msg)
core.set_prim_eager_enabled(False)
def check_static_comp(self):
paddle.enable_static()
......
......@@ -213,9 +213,7 @@ class TestSliceOp_decs_dim_6(TestSliceOp_decs_dim):
class TestSliceOp_starts_ListTensor(OpTest):
def setUp(self):
self.op_type = "slice"
self.prim_op_type = "prim"
self.python_api = paddle.slice
# self.enable_cinn = False
self.config()
starts_tensor = []
......@@ -244,12 +242,10 @@ class TestSliceOp_starts_ListTensor(OpTest):
self.starts_infer = [-1, 0, -1]
def test_check_output(self):
self.check_output(check_prim=True)
self.check_output()
def test_check_grad_normal(self):
self.check_grad(
['Input'], 'Out', max_relative_error=0.006, check_prim=True
)
self.check_grad(['Input'], 'Out', max_relative_error=0.006)
# Situation 2: starts(list, have tensor), ends(list, no tensor)
......@@ -257,7 +253,6 @@ class TestSliceOp_starts_ListTensor(OpTest):
class TestSliceOp_decs_dim_starts_ListTensor(OpTest):
def setUp(self):
self.op_type = "slice"
self.prim_op_type = "prim"
self.python_api = paddle.slice
self.config()
......@@ -290,12 +285,10 @@ class TestSliceOp_decs_dim_starts_ListTensor(OpTest):
self.starts_infer = [1, -1, 2]
def test_check_output(self):
self.check_output(check_prim=True)
self.check_output()
def test_check_grad_normal(self):
self.check_grad(
['Input'], 'Out', max_relative_error=0.006, check_prim=True
)
self.check_grad(['Input'], 'Out', max_relative_error=0.006)
class TestSliceOp_decs_dim_5_starts_ListTensor(
......@@ -318,7 +311,6 @@ class TestSliceOp_decs_dim_5_starts_ListTensor(
class TestSliceOp_decs_dim_starts_OneTensor(OpTest):
def setUp(self):
self.op_type = "slice"
self.prim_op_type = "prim"
self.python_api = paddle.slice
self.config()
self.inputs = {
......@@ -344,12 +336,10 @@ class TestSliceOp_decs_dim_starts_OneTensor(OpTest):
self.out = self.input[1, 0:3, 2:4, :]
def test_check_output(self):
self.check_output(check_prim=True)
self.check_output()
def test_check_grad_normal(self):
self.check_grad(
['Input'], 'Out', max_relative_error=0.006, check_prim=True
)
self.check_grad(['Input'], 'Out', max_relative_error=0.006)
# Situation 4: starts(tensor), ends(tensor)
......@@ -357,7 +347,6 @@ class TestSliceOp_decs_dim_starts_OneTensor(OpTest):
class TestSliceOp_starts_OneTensor_ends_OneTensor(OpTest):
def setUp(self):
self.op_type = "slice"
self.prim_op_type = "prim"
self.python_api = paddle.slice
self.config()
......@@ -383,12 +372,10 @@ class TestSliceOp_starts_OneTensor_ends_OneTensor(OpTest):
self.out = self.input[1:3, 0:3, 2:4, :]
def test_check_output(self):
self.check_output(check_prim=True)
self.check_output()
def test_check_grad_normal(self):
self.check_grad(
['Input'], 'Out', max_relative_error=0.006, check_prim=True
)
self.check_grad(['Input'], 'Out', max_relative_error=0.006)
# Situation 5: starts(tensor), ends(tensor)
......@@ -396,7 +383,6 @@ class TestSliceOp_starts_OneTensor_ends_OneTensor(OpTest):
class TestSliceOp_decs_dim_starts_and_ends_OneTensor(OpTest):
def setUp(self):
self.op_type = "slice"
self.prim_op_type = "prim"
self.python_api = paddle.slice
self.config()
self.inputs = {
......@@ -423,12 +409,10 @@ class TestSliceOp_decs_dim_starts_and_ends_OneTensor(OpTest):
self.out = self.input[1, 0, 2:4, :]
def test_check_output(self):
self.check_output(check_prim=True)
self.check_output()
def test_check_grad_normal(self):
self.check_grad(
['Input'], 'Out', max_relative_error=0.006, check_prim=True
)
self.check_grad(['Input'], 'Out', max_relative_error=0.006)
# Situation 6: starts(tensor), ends(list, have tensor)
......@@ -436,7 +420,6 @@ class TestSliceOp_decs_dim_starts_and_ends_OneTensor(OpTest):
class TestSliceOp_starts_OneTensor_ends_ListTensor(OpTest):
def setUp(self):
self.op_type = "slice"
self.prim_op_type = "prim"
self.python_api = paddle.slice
self.config()
......@@ -470,12 +453,10 @@ class TestSliceOp_starts_OneTensor_ends_ListTensor(OpTest):
self.ends_infer = [-1, 3, 4]
def test_check_output(self):
self.check_output(check_prim=True)
self.check_output()
def test_check_grad_normal(self):
self.check_grad(
['Input'], 'Out', max_relative_error=0.006, check_prim=True
)
self.check_grad(['Input'], 'Out', max_relative_error=0.006)
# Test CUDA float16
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册