Unverified commit bbca66f2, authored by Jiabin Yang, committed by GitHub

【Prim】Fix slice error and eager comp (#51086)

* fix attrs copy error

* fix bert by fix slice error

* fix op test
Parent 41e5667b
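This commit splits the eager-mode composite (prim) backward switch from the static-graph one: the generated eager grad nodes now check `PrimCommonUtils::IsEagerPrimEnabled()`, and the flag is exposed to Python as `paddle.fluid.core.set_prim_eager_enabled` / `core._is_eager_prim_enabled`. A minimal usage sketch, assuming a build that carries these bindings (tanh is one of the ops whose composite backward is exercised by the updated tests):

```python
import paddle
from paddle.fluid import core

# Route eager-mode backward through the composite (prim) rules only;
# the static-graph switch (_set_prim_backward_enabled) is left untouched.
core.set_prim_eager_enabled(True)
assert core._is_eager_prim_enabled()

x = paddle.rand([2, 3])
x.stop_gradient = False
y = paddle.tanh(x)                  # tanh has a composite backward rule
y.backward(paddle.ones_like(y))
print(x.grad.shape)                 # [2, 3]

core.set_prim_eager_enabled(False)  # restore the default, as the tests do
```

The updated unit tests below follow the same pattern, toggling this flag around each composite-gradient comparison.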
@@ -1840,7 +1840,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
         if is_composite_grad_api and next_grad_node_creation_str != '':
             next_grad_node_creation_str = f"""
-  if (!paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {{
+  if (!paddle::prim::PrimCommonUtils::IsEagerPrimEnabled()) {{
     {next_grad_node_creation_str}
   }}
 """
@@ -2260,7 +2260,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
         # TODO(Ruting):using composite only when we don't have backward kernel in the future.
         elif is_composite_grad_api:
             grad_function_call_str = f"""
-  if (paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {{
+  if (paddle::prim::PrimCommonUtils::IsEagerPrimEnabled()) {{
  {indent}{composite_grad_api_namespace}{composite_grad_api_name}{composite_template_name}({composite_grad_api_args_str});
  VLOG(4) << "Composite api {composite_grad_api_name} is called ";
  }}else{{
...
@@ -423,19 +423,25 @@ class SliceCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
     auto dx_ptr = this->GetOutputPtr(&input_grad);
     std::string dx_name = this->GetOutputName(input_grad);
-    auto axes = this->Attr<std::vector<int64_t>>("axes");
-    auto starts = this->Attr<std::vector<int64_t>>("starts");
-    auto ends = this->Attr<std::vector<int64_t>>("ends");
-    auto infer_flags = this->Attr<std::vector<int64_t>>("infer_flags");
-    auto decrease_axis = this->Attr<std::vector<int64_t>>("decrease_axis");
+    auto axes = this->Attr<std::vector<int>>("axes");
+    auto starts = this->Attr<std::vector<int>>("starts");
+    auto ends = this->Attr<std::vector<int>>("ends");
+    auto infer_flags = this->Attr<std::vector<int>>("infer_flags");
+    auto decrease_axis = this->Attr<std::vector<int>>("decrease_axis");
     VLOG(6) << "Runing slice_grad composite func";
+    std::vector<int64_t> new_axes =
+        std::vector<int64_t>(axes.begin(), axes.end());
+    std::vector<int64_t> new_infer_flags =
+        std::vector<int64_t>(infer_flags.begin(), infer_flags.end());
+    std::vector<int64_t> new_decrease_axis =
+        std::vector<int64_t>(decrease_axis.begin(), decrease_axis.end());
     prim::slice_grad<prim::DescTensor>(input,
                                        out_grad,
-                                       axes,
+                                       new_axes,
                                        paddle::experimental::IntArray(starts),
                                        paddle::experimental::IntArray(ends),
-                                       infer_flags,
-                                       decrease_axis,
+                                       new_infer_flags,
+                                       new_decrease_axis,
                                        dx_ptr);
     this->RecoverOutputName(input_grad, dx_name);
   }
@@ -478,6 +484,7 @@ REGISTER_OPERATOR(slice,
                   ops::SliceOpMaker,
                   ops::SliceOpGradMaker<paddle::framework::OpDesc>,
                   ops::SliceOpGradMaker<paddle::imperative::OpBase>,
+                  ops::SliceCompositeGradOpMaker,
                   ops::SliceOpVarTypeInference);
 REGISTER_OPERATOR(slice_grad,
                   ops::SliceOpGrad,
...
@@ -704,6 +704,7 @@ void slice_grad(const Tensor& input,
   if (input_grad) {
     size_t rank = input.dims().size();
     auto out_dims = out_grad.dims();
+    std::vector<int64_t> origin_out_shape;
     auto in_dims = input.dims();
     auto decrease_size = decrease_axis.size();
@@ -712,7 +713,7 @@ void slice_grad(const Tensor& input,
         // all dims decrease
         out_dims = phi::make_ddim(std::vector<int>(decrease_size, 1));
       } else {
-        std::vector<int> origin_out_shape(out_dims.size() + decrease_size, -1);
+        origin_out_shape.resize(out_dims.size() + decrease_size, -1);
         for (size_t i = 0; i < decrease_size; ++i) {
           origin_out_shape[decrease_axis[i]] = 1;
         }
@@ -734,7 +735,6 @@ void slice_grad(const Tensor& input,
       offsets[i] = 0;
       extents[i] = out_dims[i];
     }
-
    for (size_t i = 0; i < axes.size(); ++i) {
      int axis = axes[i];
      int64_t start = starts[i] < 0 ? (starts[i] + in_dims[axis]) : starts[i];
@@ -747,9 +747,15 @@ void slice_grad(const Tensor& input,
      paddings.push_back(offsets[i]);
      paddings.push_back((in_dims[i] - out_dims[i]) - offsets[i]);
    }
-
-    auto out_tmp = pad<T>(out_grad, paddings, 0.0);
-    set_output<T>(out_tmp, input_grad);
+    if (decrease_size > 0 &&
+        (decrease_size != static_cast<size_t>(in_dims.size()))) {
+      auto out_tmp =
+          pad<T>(reshape<T>(out_grad, origin_out_shape), paddings, 0.0);
+      set_output<T>(out_tmp, input_grad);
+    } else {
+      auto out_tmp = pad<T>(out_grad, paddings, 0.0);
+      set_output<T>(out_tmp, input_grad);
+    }
   }
 }
...
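The `slice_grad` change above is the `decrease_axis` fix: when the forward slice dropped size-1 axes, the incoming `out_grad` has a lower rank than `pad` expects, so it is first reshaped to `origin_out_shape` (the sliced shape with the decreased axes restored as 1) and only then padded back up to the input shape. A rough NumPy illustration of that idea, not the Paddle implementation:

```python
import numpy as np

# Forward: input[1, 0:3, :] with decrease_axis=[0] drops the first axis.
inp = np.random.rand(4, 5, 6)
out = inp[1, 0:3, :]                          # shape (3, 6) after the decrease

# Backward: the gradient arrives with the decreased shape (3, 6).
out_grad = np.ones_like(out)

# Step 1: restore the decreased axis as a size-1 dim -> origin_out_shape.
origin_out_shape = (1, 3, 6)
g = out_grad.reshape(origin_out_shape)

# Step 2: pad back to the input shape; offsets come from starts/ends.
pad_width = [(1, 2), (0, 2), (0, 0)]          # (before, after) per axis
input_grad = np.pad(g, pad_width, constant_values=0.0)
assert input_grad.shape == inp.shape          # (4, 5, 6)
```

Without the reshape, padding a rank-2 gradient toward a rank-3 input fails, which appears to be the slice error the commit title refers to.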
@@ -20,5 +20,6 @@ StaticCompositeContext* StaticCompositeContext::static_composite_context_ =
     new StaticCompositeContext();
 thread_local bool StaticCompositeContext::enable_bwd_prim_ = false;
 thread_local bool StaticCompositeContext::enable_fwd_prim_ = false;
+thread_local bool StaticCompositeContext::enable_eager_prim_ = false;
 }  // namespace prim
 }  // namespace paddle
@@ -65,6 +65,12 @@ class StaticCompositeContext {
   bool IsFwdPrimEnabled() { return enable_fwd_prim_; }

+  void SetEagerPrimEnabled(bool enable_prim) {
+    enable_eager_prim_ = enable_prim;
+  }
+
+  bool IsEagerPrimEnabled() { return enable_eager_prim_; }
+
   void SetAllPrimEnabled(bool enable_prim) {
     enable_fwd_prim_ = enable_prim;
     enable_bwd_prim_ = enable_prim;
@@ -102,6 +108,7 @@ class StaticCompositeContext {
   std::map<std::string, std::string> target_grad_name_;
   static thread_local bool enable_bwd_prim_;
   static thread_local bool enable_fwd_prim_;
+  static thread_local bool enable_eager_prim_;
   static StaticCompositeContext* static_composite_context_;
   DISABLE_COPY_AND_ASSIGN(StaticCompositeContext);
 };
...
@@ -27,6 +27,14 @@ void PrimCommonUtils::SetBwdPrimEnabled(bool enable_prim) {
   StaticCompositeContext::Instance().SetBwdPrimEnabled(enable_prim);
 }

+bool PrimCommonUtils::IsEagerPrimEnabled() {
+  return StaticCompositeContext::Instance().IsEagerPrimEnabled();
+}
+
+void PrimCommonUtils::SetEagerPrimEnabled(bool enable_prim) {
+  StaticCompositeContext::Instance().SetEagerPrimEnabled(enable_prim);
+}
+
 bool PrimCommonUtils::IsFwdPrimEnabled() {
   return StaticCompositeContext::Instance().IsFwdPrimEnabled();
 }
...
@@ -23,6 +23,8 @@ class PrimCommonUtils {
  public:
   static bool IsBwdPrimEnabled();
   static void SetBwdPrimEnabled(bool enabled);
+  static bool IsEagerPrimEnabled();
+  static void SetEagerPrimEnabled(bool enabled);
   static bool IsFwdPrimEnabled();
   static void SetFwdPrimEnabled(bool enabled);
   static void SetAllPrimEnabled(bool enabled);
...
@@ -681,6 +681,10 @@ PYBIND11_MODULE(libpaddle, m) {
         &paddle::prim::PrimCommonUtils::IsFwdPrimEnabled);
   m.def("__set_all_prim_enabled",
         &paddle::prim::PrimCommonUtils::SetAllPrimEnabled);
+  m.def("_is_eager_prim_enabled",
+        &paddle::prim::PrimCommonUtils::IsEagerPrimEnabled);
+  m.def("__set_eager_prim_enabled",
+        &paddle::prim::PrimCommonUtils::SetEagerPrimEnabled);
   m.def("_set_prim_target_grad_name",
         &paddle::prim::PrimCommonUtils::SetTargetGradName);
   m.def("set_num_threads", &platform::SetNumThreads);
...
@@ -316,6 +316,8 @@ try:
         from .libpaddle import __set_fwd_prim_enabled
         from .libpaddle import _is_fwd_prim_enabled
         from .libpaddle import __set_all_prim_enabled
+        from .libpaddle import _is_eager_prim_enabled
+        from .libpaddle import __set_eager_prim_enabled
         from .libpaddle import _set_prim_target_grad_name

         # custom devivce
@@ -475,26 +477,36 @@ def _set_prim_forward_blacklist(ops=None):
 def _set_prim_backward_enabled(value):
     __set_bwd_prim_enabled(bool(value))
-    print("backward prim enabled: ", bool(_is_bwd_prim_enabled()))
+    if os.getenv("FLAGS_prim_log") is "1":
+        print("backward prim enabled: ", bool(_is_bwd_prim_enabled()))


 def _set_prim_forward_enabled(value):
     __set_fwd_prim_enabled(bool(value))
-    print("forward prim enabled: ", bool(_is_fwd_prim_enabled()))
+    if os.getenv("FLAGS_prim_log") is "1":
+        print("forward prim enabled: ", bool(_is_fwd_prim_enabled()))


+def set_prim_eager_enabled(value):
+    __set_eager_prim_enabled(bool(value))
+    if os.getenv("FLAGS_prim_log") is "1":
+        print("eager prim enabled: ", bool(_is_eager_prim_enabled()))
+
+
 def _set_prim_all_enabled(value):
     __set_all_prim_enabled(bool(value))
-    print(
-        "all prim enabled: ",
-        bool(_is_fwd_prim_enabled() and _is_bwd_prim_enabled()),
-    )
+    if os.getenv("FLAGS_prim_log") is "1":
+        print(
+            "all prim enabled: ",
+            bool(_is_fwd_prim_enabled() and _is_bwd_prim_enabled()),
+        )


 def __sync_prim_backward_status():
     flag_value = os.getenv("FLAGS_prim_backward")
     if flag_value is None:
-        print("backward prim enabled: ", bool(_is_bwd_prim_enabled()))
+        if os.getenv("FLAGS_prim_log") is "1":
+            print("backward prim enabled: ", bool(_is_bwd_prim_enabled()))
     else:
         __sync_stat_with_flag("FLAGS_prim_backward")
@@ -502,7 +514,8 @@ def __sync_prim_backward_status():
 def __sync_prim_forward_status():
     flag_value = os.getenv("FLAGS_prim_forward")
     if flag_value is None:
-        print("forward prim enabled: ", bool(_is_fwd_prim_enabled()))
+        if os.getenv("FLAGS_prim_log") is 1:
+            print("forward prim enabled: ", bool(_is_fwd_prim_enabled()))
     else:
         __sync_stat_with_flag("FLAGS_prim_forward")
...
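One caveat worth noting in the logging gate above: `os.getenv` returns a string (or `None`), so `is "1"` compares object identity rather than value (and `is` with a literal triggers a `SyntaxWarning` on Python 3.8+), while the `is 1` variant in `__sync_prim_forward_status` can never be true. An equality check is the robust form; a minimal sketch with a hypothetical `_prim_log_enabled` helper, not the code in this commit:

```python
import os

def _prim_log_enabled():
    # FLAGS_prim_log is the env var gated on above; compare by value,
    # not identity, because os.getenv returns a str (or None).
    return os.getenv("FLAGS_prim_log") == "1"

if _prim_log_enabled():
    print("prim logging is on")
```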
@@ -207,7 +207,7 @@ class BertPooler(nn.Layer):
 class BertModel(nn.Layer):
-    def __init__(self, config: BertConfig):
+    def __init__(self, config: BertConfig, to_static):
         super(BertModel, self).__init__()
         self.config = config
         self.pad_token_id = config.pad_token_id
@@ -247,6 +247,8 @@ class BertModel(nn.Layer):
         self.encoder = nn.TransformerEncoder(
             encoder_layer, config.num_hidden_layers
         )
+        if to_static:
+            self.encoder = paddle.jit.to_static(self.encoder)
         self.pooler = BertPooler(config)
         # self.apply(self.init_weights)
@@ -364,10 +366,10 @@ class BertModel(nn.Layer):
 class Bert(nn.Layer):
-    def __init__(self):
+    def __init__(self, to_static):
         super(Bert, self).__init__()
         config = BertConfig()
-        self.bert = BertModel(config)
+        self.bert = BertModel(config, to_static)
         self.cls = BertPretrainingHeads(
             config,
             embedding_weights=self.bert.embeddings.word_embeddings.weight,
...
@@ -58,7 +58,7 @@ def train(to_static, enable_prim, enable_cinn):
         worker_init=None,
     )

-    bert = Bert()
+    bert = Bert(to_static)
     criterion = BertPretrainingCriterion()
     if to_static:
         # input_sepc = [
@@ -72,9 +72,6 @@ def train(to_static, enable_prim, enable_cinn):
         build_strategy = paddle.static.BuildStrategy()
         if enable_cinn:
             build_strategy.build_cinn_pass = True
-        bert = paddle.jit.to_static(
-            bert, input_sepc, build_strategy=build_strategy
-        )

     optimizer = fluid.optimizer.Adam(parameter_list=bert.parameters())
...
@@ -58,6 +58,9 @@ class TestPrimFlags(unittest.TestCase):
         core.check_and_set_prim_all_enabled()
         self.assertFalse(core._is_fwd_prim_enabled())

+        core.set_prim_eager_enabled(True)
+        self.assertTrue(core._is_eager_prim_enabled())
+
         with self.assertRaises(TypeError):
             core._test_use_sync("aaaa")
...
@@ -20,7 +20,7 @@ import parameterized as param
 import paddle
 from paddle.fluid import core

-core._set_prim_backward_enabled(True)
+core.set_prim_eager_enabled(True)


 @param.parameterized_class(
@@ -61,7 +61,7 @@ class TestAddGradComp(unittest.TestCase):
     def test_add_grad_comp(self):
         def actual(primal0, primal1):
-            core._set_prim_backward_enabled(True)
+            core.set_prim_eager_enabled(True)
             paddle.disable_static()
             x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
             y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
@@ -72,7 +72,7 @@ class TestAddGradComp(unittest.TestCase):
             return res[0].numpy(), res[1].numpy()

         def desired(primal0, primal1):
-            core._set_prim_backward_enabled(False)
+            core.set_prim_eager_enabled(False)
             paddle.disable_static()
             x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
             y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
@@ -98,7 +98,7 @@ class TestAddGradComp(unittest.TestCase):
             rtol=1e-6,
             atol=0,
         )
-        core._set_prim_backward_enabled(False)
+        core.set_prim_eager_enabled(False)


 if __name__ == '__main__':
...
@@ -52,7 +52,7 @@ class TestCastGradComp(unittest.TestCase):
         cls.cotangent = cls.cotangent.astype(cls.src_dtype)

     def test_cast_grad_comp(self):
-        core._set_prim_backward_enabled(True)
+        core.set_prim_eager_enabled(True)

         def actual(primal, cotangent):
             x = paddle.to_tensor(primal)
@@ -78,7 +78,7 @@ class TestCastGradComp(unittest.TestCase):
             rtol=1e-6,
             atol=0,
         )
-        core._set_prim_backward_enabled(False)
+        core.set_prim_eager_enabled(False)


 if __name__ == '__main__':
...
@@ -20,7 +20,7 @@ import parameterized as param
 import paddle
 from paddle.fluid import core

-core._set_prim_backward_enabled(True)
+core.set_prim_eager_enabled(True)


 @param.parameterized_class(
@@ -61,7 +61,7 @@ class TestDivGradComp(unittest.TestCase):
     def test_div_grad_comp(self):
         def actual(primal0, primal1):
-            core._set_prim_backward_enabled(True)
+            core.set_prim_eager_enabled(True)
             paddle.disable_static()
             x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
             y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
@@ -72,7 +72,7 @@ class TestDivGradComp(unittest.TestCase):
             return res[0].numpy(), res[1].numpy()

         def desired(primal0, primal1):
-            core._set_prim_backward_enabled(False)
+            core.set_prim_eager_enabled(False)
             paddle.disable_static()
             x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
             y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
@@ -98,7 +98,7 @@ class TestDivGradComp(unittest.TestCase):
             rtol=1e-6,
             atol=0,
         )
-        core._set_prim_backward_enabled(False)
+        core.set_prim_eager_enabled(False)


 if __name__ == '__main__':
...
@@ -32,14 +32,14 @@ from paddle.fluid import core
 class TestExpGradComp(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        core._set_prim_backward_enabled(True)
+        core.set_prim_eager_enabled(True)
         cls.primal = cls.primal.astype(cls.dtype)
         if cls.cotangent is not None:
             cls.cotangent = cls.cotangent.astype(cls.dtype)

     @classmethod
     def tearDownClass(cls):
-        core._set_prim_backward_enabled(False)
+        core.set_prim_eager_enabled(False)

     def test_exp_grad_comp(self):
         def actual(primal, cotangent):
...
@@ -62,7 +62,7 @@ class TestExpandGradComp(unittest.TestCase):
     @classmethod
     def tearDownClass(cls):
-        core._set_prim_backward_enabled(False)
+        core.set_prim_eager_enabled(False)

     def test_comp(self):
         def func(primal, cotangent, shape):
@@ -74,11 +74,11 @@ class TestExpandGradComp(unittest.TestCase):
             ]

         def actual(primal, cotangent, shape):
-            core._set_prim_backward_enabled(True)
+            core.set_prim_eager_enabled(True)
             return func(primal, cotangent, shape)

         def desired(primal, cotangent, shape):
-            core._set_prim_backward_enabled(False)
+            core.set_prim_eager_enabled(False)
             return func(primal, cotangent, shape)

         np.testing.assert_allclose(
...
@@ -75,11 +75,11 @@ class TestGatherGradComp(unittest.TestCase):
     @classmethod
     def tearDownClass(cls):
-        core._set_prim_backward_enabled(False)
+        core.set_prim_eager_enabled(False)

     def test_exp_grad_comp(self):
         def actual(primal0, index, axis):
-            core._set_prim_backward_enabled(True)
+            core.set_prim_eager_enabled(True)
             paddle.disable_static()
             x = paddle.to_tensor(
                 primal0, dtype=primal0.dtype, stop_gradient=False
@@ -92,7 +92,7 @@ class TestGatherGradComp(unittest.TestCase):
             return res[0].numpy()

         def desired(primal0, index, axis):
-            core._set_prim_backward_enabled(False)
+            core.set_prim_eager_enabled(False)
             paddle.disable_static()
             x = paddle.to_tensor(
                 primal0, dtype=primal0.dtype, stop_gradient=False
...
@@ -20,7 +20,7 @@ import parameterized as param
 import paddle
 from paddle.fluid import core

-core._set_prim_backward_enabled(True)
+core.set_prim_eager_enabled(True)

 # vector * vector out.shape = (1)
 # matrix * vector out.shape = (2)
@@ -267,7 +267,7 @@ class TestMatmulDoubleGradComp(unittest.TestCase):
     def test_matmul_grad_comp(self):
         def actual(primal0, primal1, trans_0, trans_1, dtype_):
-            core._set_prim_backward_enabled(True)
+            core.set_prim_eager_enabled(True)
             paddle.disable_static()
             x = paddle.to_tensor(primal0, dtype=dtype_, stop_gradient=False)
             y = paddle.to_tensor(primal1, dtype=dtype_, stop_gradient=False)
@@ -287,7 +287,7 @@ class TestMatmulDoubleGradComp(unittest.TestCase):
         )

         def desired(primal0, primal1, trans_0, trans_1, dtype_):
-            core._set_prim_backward_enabled(False)
+            core.set_prim_eager_enabled(False)
             paddle.disable_static()
             x = paddle.to_tensor(primal0, dtype=dtype_, stop_gradient=False)
             y = paddle.to_tensor(primal1, dtype=dtype_, stop_gradient=False)
@@ -428,7 +428,7 @@ class TestMatmulTribleGradComp(unittest.TestCase):
     def test_matmul_grad_comp(self):
         def actual(primal0, primal1, trans_0, trans_1, dtype_):
-            core._set_prim_backward_enabled(True)
+            core.set_prim_eager_enabled(True)
             paddle.disable_static()
             x = paddle.to_tensor(primal0, dtype=dtype_, stop_gradient=False)
             y = paddle.to_tensor(primal1, dtype=dtype_, stop_gradient=False)
@@ -465,7 +465,7 @@ class TestMatmulTribleGradComp(unittest.TestCase):
         )

         def desired(primal0, primal1, trans_0, trans_1, dtype_):
-            core._set_prim_backward_enabled(False)
+            core.set_prim_eager_enabled(False)
             paddle.disable_static()
             x = paddle.to_tensor(primal0, dtype=dtype_, stop_gradient=False)
             y = paddle.to_tensor(primal1, dtype=dtype_, stop_gradient=False)
@@ -549,7 +549,7 @@ class TestMatmulTribleGradComp(unittest.TestCase):
             atol=TOLERANCE[d_type]['atol'],
         )
-        core._set_prim_backward_enabled(False)
+        core.set_prim_eager_enabled(False)


 if __name__ == '__main__':
...
@@ -81,10 +81,10 @@ class TestMultiplyGradComp(unittest.TestCase):
         return [g for g in grads if g is not None]

     def test_comp(self):
-        core._set_prim_backward_enabled(True)
+        core.set_prim_eager_enabled(True)
         actual = self.vjp()

-        core._set_prim_backward_enabled(False)
+        core.set_prim_eager_enabled(False)
         desired = self.vjp()

         for i, j in zip(actual, desired):
...
@@ -20,7 +20,7 @@ import parameterized as param
 import paddle
 from paddle.fluid import core

-core._set_prim_backward_enabled(True)
+core.set_prim_eager_enabled(True)


 @param.parameterized_class(
@@ -42,7 +42,7 @@ class TestReshapeGradComp(unittest.TestCase):
     def test_reshape_grad_comp(self):
         def actual(primal0, shape):
-            core._set_prim_backward_enabled(True)
+            core.set_prim_eager_enabled(True)
             paddle.disable_static()
             x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
             x.stop_gradient = False
@@ -51,7 +51,7 @@ class TestReshapeGradComp(unittest.TestCase):
             return res[0].numpy()

         def desired(primal0, shape):
-            core._set_prim_backward_enabled(False)
+            core.set_prim_eager_enabled(False)
             paddle.disable_static()
             x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
             x.stop_gradient = False
@@ -69,7 +69,7 @@ class TestReshapeGradComp(unittest.TestCase):
             rtol=1e-6,
             atol=0,
         )
-        core._set_prim_backward_enabled(False)
+        core.set_prim_eager_enabled(False)


 if __name__ == '__main__':
...
@@ -22,7 +22,7 @@ import parameterized as param
 import paddle
 from paddle.fluid import core

-core._set_prim_backward_enabled(True)
+core.set_prim_eager_enabled(True)


 @param.parameterized_class(
@@ -57,7 +57,7 @@ class TestSqrtGradComp(unittest.TestCase):
             rtol=1e-6,
             atol=0,
         )
-        core._set_prim_backward_enabled(False)
+        core.set_prim_eager_enabled(False)


 if __name__ == '__main__':
...
@@ -20,7 +20,7 @@ import parameterized as param
 import paddle
 from paddle.fluid import core

-core._set_prim_backward_enabled(True)
+core.set_prim_eager_enabled(True)


 @param.parameterized_class(
@@ -61,7 +61,7 @@ class TestSubGradComp(unittest.TestCase):
     def test_sub_grad_comp(self):
         def actual(primal0, primal1):
-            core._set_prim_backward_enabled(True)
+            core.set_prim_eager_enabled(True)
             paddle.disable_static()
             x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
             y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
@@ -72,7 +72,7 @@ class TestSubGradComp(unittest.TestCase):
             return res[0].numpy(), res[1].numpy()

         def desired(primal0, primal1):
-            core._set_prim_backward_enabled(False)
+            core.set_prim_eager_enabled(False)
             paddle.disable_static()
             x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
             y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
@@ -98,7 +98,7 @@ class TestSubGradComp(unittest.TestCase):
             rtol=1e-6,
             atol=0,
         )
-        core._set_prim_backward_enabled(False)
+        core.set_prim_eager_enabled(False)


 if __name__ == '__main__':
...
@@ -21,7 +21,7 @@ from paddle.fluid import core
 def actual(primal, cotangent, axis, keep_dim):
-    core._set_prim_backward_enabled(False)
+    core.set_prim_eager_enabled(False)
     x = paddle.to_tensor(primal, dtype='float32', stop_gradient=False)
     v = paddle.to_tensor(cotangent, dtype='float32', stop_gradient=False)
     y = paddle.sum(x, axis=axis, keepdim=keep_dim)
@@ -30,7 +30,7 @@ def actual(primal, cotangent, axis, keep_dim):
 def desired(primal, cotangent, axis, keep_dim):
-    core._set_prim_backward_enabled(True)
+    core.set_prim_eager_enabled(True)
     x = paddle.to_tensor(primal, dtype='float32', stop_gradient=False)
     v = paddle.to_tensor(cotangent, dtype='float32', stop_gradient=False)
     y = paddle.sum(x, axis=axis, keepdim=keep_dim)
...
@@ -20,7 +20,7 @@ import parameterized as param
 import paddle
 from paddle.fluid import core

-core._set_prim_backward_enabled(True)
+core.set_prim_eager_enabled(True)


 @param.parameterized_class(
@@ -68,7 +68,7 @@ class TestTanhGradComp(unittest.TestCase):
             rtol=1e-6,
             atol=0,
         )
-        core._set_prim_backward_enabled(False)
+        core.set_prim_eager_enabled(False)


 if __name__ == '__main__':
...
@@ -20,7 +20,7 @@ import parameterized as param
 import paddle
 from paddle.fluid import core

-core._set_prim_backward_enabled(True)
+core.set_prim_eager_enabled(True)


 @param.parameterized_class(
@@ -72,7 +72,7 @@ class TestTransposeGradComp(unittest.TestCase):
     def test_transpose_grad_comp(self):
         def actual(primal0, shape):
-            core._set_prim_backward_enabled(True)
+            core.set_prim_eager_enabled(True)
             paddle.disable_static()
             x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
             x.stop_gradient = False
@@ -81,7 +81,7 @@ class TestTransposeGradComp(unittest.TestCase):
             return res[0].numpy()

         def desired(primal0, shape):
-            core._set_prim_backward_enabled(False)
+            core.set_prim_eager_enabled(False)
             paddle.disable_static()
             x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
             x.stop_gradient = False
@@ -99,7 +99,7 @@ class TestTransposeGradComp(unittest.TestCase):
             rtol=1e-6,
             atol=0,
         )
-        core._set_prim_backward_enabled(False)
+        core.set_prim_eager_enabled(False)


 if __name__ == '__main__':
...
@@ -906,7 +906,7 @@ class PrimGradChecker(PrimForwardChecker):
             paddle.device.set_device("gpu:0")
         atol = self.rev_comp_atol
         rtol = self.rev_comp_rtol
-        core._set_prim_backward_enabled(self.enable_rev_comp)
+        core.set_prim_eager_enabled(self.enable_rev_comp)
         actual_ret = self.get_eager_desire()
         # check static forward
         if len(actual_ret) != len(self.eager_desire):
@@ -941,6 +941,7 @@ class PrimGradChecker(PrimForwardChecker):
                     )
                 )
                 raise RuntimeError(msg)
+        core.set_prim_eager_enabled(False)

     def check_static_comp(self):
         paddle.enable_static()
...
@@ -213,9 +213,7 @@ class TestSliceOp_decs_dim_6(TestSliceOp_decs_dim):
 class TestSliceOp_starts_ListTensor(OpTest):
     def setUp(self):
         self.op_type = "slice"
-        self.prim_op_type = "prim"
         self.python_api = paddle.slice
-        # self.enable_cinn = False
         self.config()

         starts_tensor = []
@@ -244,12 +242,10 @@ class TestSliceOp_starts_ListTensor(OpTest):
         self.starts_infer = [-1, 0, -1]

     def test_check_output(self):
-        self.check_output(check_prim=True)
+        self.check_output()

     def test_check_grad_normal(self):
-        self.check_grad(
-            ['Input'], 'Out', max_relative_error=0.006, check_prim=True
-        )
+        self.check_grad(['Input'], 'Out', max_relative_error=0.006)


 # Situation 2: starts(list, have tensor), ends(list, no tensor)
@@ -257,7 +253,6 @@ class TestSliceOp_starts_ListTensor(OpTest):
 class TestSliceOp_decs_dim_starts_ListTensor(OpTest):
     def setUp(self):
         self.op_type = "slice"
-        self.prim_op_type = "prim"
         self.python_api = paddle.slice
         self.config()
@@ -290,12 +285,10 @@ class TestSliceOp_decs_dim_starts_ListTensor(OpTest):
         self.starts_infer = [1, -1, 2]

     def test_check_output(self):
-        self.check_output(check_prim=True)
+        self.check_output()

     def test_check_grad_normal(self):
-        self.check_grad(
-            ['Input'], 'Out', max_relative_error=0.006, check_prim=True
-        )
+        self.check_grad(['Input'], 'Out', max_relative_error=0.006)


 class TestSliceOp_decs_dim_5_starts_ListTensor(
@@ -318,7 +311,6 @@ class TestSliceOp_decs_dim_5_starts_ListTensor(
 class TestSliceOp_decs_dim_starts_OneTensor(OpTest):
     def setUp(self):
         self.op_type = "slice"
-        self.prim_op_type = "prim"
         self.python_api = paddle.slice
         self.config()
         self.inputs = {
@@ -344,12 +336,10 @@ class TestSliceOp_decs_dim_starts_OneTensor(OpTest):
         self.out = self.input[1, 0:3, 2:4, :]

     def test_check_output(self):
-        self.check_output(check_prim=True)
+        self.check_output()

     def test_check_grad_normal(self):
-        self.check_grad(
-            ['Input'], 'Out', max_relative_error=0.006, check_prim=True
-        )
+        self.check_grad(['Input'], 'Out', max_relative_error=0.006)


 # Situation 4: starts(tensor), ends(tensor)
@@ -357,7 +347,6 @@ class TestSliceOp_decs_dim_starts_OneTensor(OpTest):
 class TestSliceOp_starts_OneTensor_ends_OneTensor(OpTest):
     def setUp(self):
         self.op_type = "slice"
-        self.prim_op_type = "prim"
         self.python_api = paddle.slice
         self.config()
@@ -383,12 +372,10 @@ class TestSliceOp_starts_OneTensor_ends_OneTensor(OpTest):
         self.out = self.input[1:3, 0:3, 2:4, :]

     def test_check_output(self):
-        self.check_output(check_prim=True)
+        self.check_output()

     def test_check_grad_normal(self):
-        self.check_grad(
-            ['Input'], 'Out', max_relative_error=0.006, check_prim=True
-        )
+        self.check_grad(['Input'], 'Out', max_relative_error=0.006)


 # Situation 5: starts(tensor), ends(tensor)
@@ -396,7 +383,6 @@ class TestSliceOp_starts_OneTensor_ends_OneTensor(OpTest):
 class TestSliceOp_decs_dim_starts_and_ends_OneTensor(OpTest):
     def setUp(self):
         self.op_type = "slice"
-        self.prim_op_type = "prim"
         self.python_api = paddle.slice
         self.config()
         self.inputs = {
@@ -423,12 +409,10 @@ class TestSliceOp_decs_dim_starts_and_ends_OneTensor(OpTest):
         self.out = self.input[1, 0, 2:4, :]

     def test_check_output(self):
-        self.check_output(check_prim=True)
+        self.check_output()

     def test_check_grad_normal(self):
-        self.check_grad(
-            ['Input'], 'Out', max_relative_error=0.006, check_prim=True
-        )
+        self.check_grad(['Input'], 'Out', max_relative_error=0.006)


 # Situation 6: starts(tensor), ends(list, have tensor)
@@ -436,7 +420,6 @@ class TestSliceOp_decs_dim_starts_and_ends_OneTensor(OpTest):
 class TestSliceOp_starts_OneTensor_ends_ListTensor(OpTest):
     def setUp(self):
         self.op_type = "slice"
-        self.prim_op_type = "prim"
         self.python_api = paddle.slice
         self.config()
@@ -470,12 +453,10 @@ class TestSliceOp_starts_OneTensor_ends_ListTensor(OpTest):
         self.ends_infer = [-1, 3, 4]

     def test_check_output(self):
-        self.check_output(check_prim=True)
+        self.check_output()

     def test_check_grad_normal(self):
-        self.check_grad(
-            ['Input'], 'Out', max_relative_error=0.006, check_prim=True
-        )
+        self.check_grad(['Input'], 'Out', max_relative_error=0.006)


 # Test CUDA float16
...