提交 a2ffbd53 编写于 作者: H Helin Wang

scaffolding for the new Fluid API

上级 73858547
...@@ -20,6 +20,16 @@ from framework import * ...@@ -20,6 +20,16 @@ from framework import *
import executor import executor
from executor import * from executor import *
import trainer
from trainer import Trainer
from trainer import Event
import inferencer
from inferencer import Inferencer
import params
from params import Params
import io import io
import evaluator import evaluator
import initializer import initializer
...@@ -47,7 +57,8 @@ from parallel_executor import ParallelExecutor ...@@ -47,7 +57,8 @@ from parallel_executor import ParallelExecutor
Tensor = LoDTensor Tensor = LoDTensor
__all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + [ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ +\
trainer.__all__ + inferencer.__all__ + params.__all__ + [
'io', 'io',
'initializer', 'initializer',
'layers', 'layers',
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = [
'Inferencer',
]
class Inferencer(object):
def __init__(self, network_func, params, place=None):
self.network_func = network_func
self.params = params
self.place = place
def infer(self, inputs):
pass
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import core
__all__ = [
'Params',
]
class Params(object):
def __init__(self, path=None):
self.scope = core.Scope()
if path:
self._load(path)
def _load(self, path):
pass
def save(self, path):
pass
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
__all__ = [
'Event',
'Trainer',
]
class Event(Enum):
BEGIN_EPOCH = 0
END_EPOCH = 1
BEGIN_STEP = 2
END_STEP = 3
def __init__(self):
self.step = 0
self.epoch = 0
self.type = Event.BEGIN_EPOCH
class Trainer(object):
def __init__(self, network_func, optimizer, params=None, place=None):
self.network_func = network_func
self.optimizer = optimizer
self.params = params
self.place = place
def train(self, reader, num_epochs, event_handler):
pass
def test(self, reader):
pass
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册