未验证 提交 bec79a00 编写于 作者: L littletomatodonkey 提交者: GitHub

update static method by trying-catch (#391)

上级 65003348
...@@ -19,6 +19,7 @@ from __future__ import print_function ...@@ -19,6 +19,7 @@ from __future__ import print_function
import os import os
import sys import sys
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from ppcls.modeling import get_architectures from ppcls.modeling import get_architectures
...@@ -134,3 +135,10 @@ def check_function_params(config, key): ...@@ -134,3 +135,10 @@ def check_function_params(config, key):
('params is required in {} config'.format(key)) ('params is required in {} config'.format(key))
assert isinstance(params, dict), \ assert isinstance(params, dict), \
('the params in {} config should be a dict'.format(key)) ('the params in {} config should be a dict'.format(key))
def enable_static_mode():
try:
paddle.enable_static()
except:
pass
...@@ -32,6 +32,7 @@ import program ...@@ -32,6 +32,7 @@ import program
from ppcls.data import Reader from ppcls.data import Reader
from ppcls.utils.config import get_config from ppcls.utils.config import get_config
from ppcls.utils.save_load import init_model from ppcls.utils.save_load import init_model
from ppcls.utils.check import enable_static_mode
from paddle.fluid.incubate.fleet.collective import fleet from paddle.fluid.incubate.fleet.collective import fleet
from paddle.fluid.incubate.fleet.base import role_maker from paddle.fluid.incubate.fleet.base import role_maker
...@@ -84,6 +85,6 @@ def main(args): ...@@ -84,6 +85,6 @@ def main(args):
if __name__ == '__main__': if __name__ == '__main__':
paddle.enable_static() enable_static_mode()
args = parse_args() args = parse_args()
main(args) main(args)
...@@ -31,6 +31,7 @@ import program ...@@ -31,6 +31,7 @@ import program
from ppcls.data import Reader from ppcls.data import Reader
from ppcls.utils.config import get_config from ppcls.utils.config import get_config
from ppcls.utils.save_load import init_model from ppcls.utils.save_load import init_model
from ppcls.utils.check import enable_static_mode
def parse_args(): def parse_args():
...@@ -77,6 +78,6 @@ def main(args): ...@@ -77,6 +78,6 @@ def main(args):
if __name__ == '__main__': if __name__ == '__main__':
paddle.enable_static() enable_static_mode()
args = parse_args() args = parse_args()
main(args) main(args)
...@@ -21,6 +21,7 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '../'))) ...@@ -21,6 +21,7 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
import argparse import argparse
from ppcls.modeling import architectures from ppcls.modeling import architectures
from ppcls.utils.check import enable_static_mode
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -82,5 +83,5 @@ def main(): ...@@ -82,5 +83,5 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static() enable_static_mode()
main() main()
...@@ -24,6 +24,7 @@ import paddle ...@@ -24,6 +24,7 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from ppcls.modeling import architectures from ppcls.modeling import architectures
from ppcls.utils.check import enable_static_mode
import utils import utils
...@@ -145,5 +146,5 @@ def main(): ...@@ -145,5 +146,5 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static() enable_static_mode()
main() main()
...@@ -11,6 +11,11 @@ ...@@ -11,6 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
import utils import utils
import argparse import argparse
...@@ -19,6 +24,8 @@ import numpy as np ...@@ -19,6 +24,8 @@ import numpy as np
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from ppcls.utils.check import enable_static_mode
def parse_args(): def parse_args():
def str2bool(v): def str2bool(v):
...@@ -100,5 +107,5 @@ def main(): ...@@ -100,5 +107,5 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static() enable_static_mode()
main() main()
...@@ -33,6 +33,7 @@ from paddle.fluid.incubate.fleet.collective import fleet ...@@ -33,6 +33,7 @@ from paddle.fluid.incubate.fleet.collective import fleet
from ppcls.data import Reader from ppcls.data import Reader
from ppcls.utils.config import get_config from ppcls.utils.config import get_config
from ppcls.utils.save_load import init_model, save_model from ppcls.utils.save_load import init_model, save_model
from ppcls.utils.check import enable_static_mode
from ppcls.utils import logger from ppcls.utils import logger
import program import program
...@@ -155,6 +156,6 @@ def main(args): ...@@ -155,6 +156,6 @@ def main(args):
if __name__ == '__main__': if __name__ == '__main__':
paddle.enable_static() enable_static_mode()
args = parse_args() args = parse_args()
main(args) main(args)
...@@ -31,6 +31,7 @@ import paddle.fluid as fluid ...@@ -31,6 +31,7 @@ import paddle.fluid as fluid
from ppcls.data import Reader from ppcls.data import Reader
from ppcls.utils.config import get_config from ppcls.utils.config import get_config
from ppcls.utils.save_load import init_model, save_model from ppcls.utils.save_load import init_model, save_model
from ppcls.utils.check import enable_static_mode
from ppcls.utils import logger from ppcls.utils import logger
import program import program
...@@ -164,6 +165,6 @@ def main(args): ...@@ -164,6 +165,6 @@ def main(args):
if __name__ == '__main__': if __name__ == '__main__':
paddle.enable_static() enable_static_mode()
args = parse_args() args = parse_args()
main(args) main(args)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册