提交 d8f30da7 编写于 作者: Y yuyang18

Add PyDataProvider2 DataConverter to swig api.

* fix recommendation prediction also.

ISSUE=4561941



git-svn-id: https://svn.baidu.com/idl/trunk/paddle@1449 1ad973e4-5ce8-4261-8a94-b56d1f490c56
上级 8fe4a338
......@@ -7,3 +7,4 @@ data/train.list
data/test.list
dataprovider_copy_1.py
*.pyc
output
......@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from py_paddle import swig_paddle, DataProviderWrapperConverter
from py_paddle import swig_paddle, DataProviderConverter
from common_utils import *
from paddle.trainer.config_parser import parse_config
......@@ -31,11 +31,11 @@ if __name__ == '__main__':
network = swig_paddle.GradientMachine.createFromConfigProto(conf.model_config)
assert isinstance(network, swig_paddle.GradientMachine)
network.loadParameters(model_path)
with open('meta.bin', 'rb') as f:
with open('./data/meta.bin', 'rb') as f:
meta = pickle.load(f)
headers = list(meta_to_header(meta, 'movie'))
headers.extend(list(meta_to_header(meta, 'user')))
cvt = DataProviderWrapperConverter(True, map(lambda x: x[1], headers))
cvt = DataProviderConverter(headers)
while True:
movie_id = int(raw_input("Input movie_id: "))
user_id = int(raw_input("Input user_id: "))
......@@ -45,7 +45,5 @@ if __name__ == '__main__':
data.extend(movie_meta)
data.append(user_id - 1)
data.extend(user_meta)
data = map(lambda (header, val): val if header[0] else [val],
zip(headers, data))
print "Prediction Score is %.2f" % ((network.forwardTest(cvt([
data]))[0]['value'][0][0] + 5) / 2)
print "Prediction Score is %.2f" % ((network.forwardTest(
cvt.convert([data]))[0]['value'][0][0] + 5) / 2)
......@@ -61,9 +61,7 @@ def construct_feature(name):
slot_dim = each_meta['max']
embedding = embedding_layer(input=data_layer(slot_name,
size=slot_dim),
size=256,
param_attr=ParamAttr(
sparse_update=True))
size=256)
fusion.append(fc_layer(input=embedding,
size=256))
elif type_name == 'embedding':
......
......@@ -102,8 +102,23 @@ static inline void doCopyFromSafely(std::shared_ptr<T1>& dest,
IVector* Arguments::getSlotSequenceStartPositions(size_t idx) const
throw(RangeError) {
auto& a = m->getArg(idx);
if (a.sequenceStartPositions) {
return IVector::createByPaddleVectorPtr(
&a.sequenceStartPositions->getMutableVector(false));
} else {
return nullptr;
}
}
IVector*Arguments::getSlotSubSequenceStartPositions(size_t idx) const
throw (RangeError){
auto& a = m->getArg(idx);
if (a.subSequenceStartPositions) {
return IVector::createByPaddleVectorPtr(
&a.subSequenceStartPositions->getMutableVector(false));
} else {
return nullptr;
}
}
void Arguments::setSlotSequenceStartPositions(size_t idx,
......@@ -113,6 +128,13 @@ void Arguments::setSlotSequenceStartPositions(size_t idx,
a.sequenceStartPositions = std::make_shared<paddle::ICpuGpuVector>(v);
}
void Arguments::setSlotSubSequenceStartPositions(
size_t idx, IVector *vec) throw (RangeError) {
auto& a = m->getArg(idx);
auto& v = m->cast<paddle::IVector>(vec->getSharedPtr());
a.subSequenceStartPositions = std::make_shared<paddle::ICpuGpuVector>(v);
}
IVector* Arguments::getSlotSequenceDim(size_t idx) const throw(RangeError) {
auto& a = m->getArg(idx);
return IVector::createByPaddleVectorPtr(&a.cpuSequenceDims);
......
......@@ -374,6 +374,7 @@ public:
IVector* getSlotIds(size_t idx) const throw(RangeError);
Matrix* getSlotIn(size_t idx) const throw(RangeError);
IVector* getSlotSequenceStartPositions(size_t idx) const throw(RangeError);
IVector* getSlotSubSequenceStartPositions(size_t idx) const throw(RangeError);
IVector* getSlotSequenceDim(size_t idx) const throw(RangeError);
// End Of get functions of Arguments
......@@ -390,6 +391,8 @@ public:
void setSlotIds(size_t idx, IVector* vec) throw(RangeError);
void setSlotSequenceStartPositions(size_t idx,
IVector* vec) throw(RangeError);
void setSlotSubSequenceStartPositions(size_t idx,
IVector* vec) throw (RangeError);
void setSlotSequenceDim(size_t idx, IVector* vec) throw(RangeError);
private:
......
......@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import swig_paddle
import util
from util import DataProviderWrapperConverter
from dataprovider_converter import DataProviderConverter
__all__ = ['paddle', 'DataProviderWrapperConverter', 'loadParameterFile']
__all__ = ['paddle',
'DataProviderConverter',
'DataProviderWrapperConverter', # for deprecated usage.
'loadParameterFile']
util.monkeypatches()
......@@ -209,6 +209,7 @@ def __monkeypatch_gradient_machine__():
swig_paddle.GradientMachine.getLayerOutputs = getLayerOutputs
def loadGradientMachine(config_filename, model_dir=None):
"""
Load a gradient machine from config file name/path.
......@@ -229,6 +230,7 @@ def loadGradientMachine(config_filename, model_dir=None):
network.loadParameters(model_dir)
return network
def loadParameterFile(fn):
"""
Load Paddle Parameter file to numpy.ndarray
......@@ -257,6 +259,7 @@ def loadParameterFile(fn):
else:
raise swig_paddle.UnsupportError()
class DataProviderWrapperConverter(object):
"""
A class convert DataFormat from PyDataProvider Wrapper to
......@@ -312,7 +315,8 @@ class DataProviderWrapperConverter(object):
self.cols += other
def __call__(self, slot_idx, arg):
mat = swig_paddle.Matrix.createSparse(len(self.indices) - 1, self.dim,
mat = swig_paddle.Matrix.createSparse(len(self.indices) - 1,
self.dim,
len(self.cols), True)
assert isinstance(mat, swig_paddle.Matrix)
mat.sparseCopyFrom(self.indices, self.cols)
......@@ -337,7 +341,8 @@ class DataProviderWrapperConverter(object):
self.values += map(lambda x: x[1], other)
def __call__(self, slot_idx, arg):
mat = swig_paddle.Matrix.createSparse(len(self.indices) - 1, self.dim,
mat = swig_paddle.Matrix.createSparse(len(self.indices) - 1,
self.dim,
len(self.cols), False)
assert isinstance(mat, swig_paddle.Matrix)
mat.sparseCopyFrom(self.indices, self.cols, self.values)
......@@ -373,7 +378,7 @@ class DataProviderWrapperConverter(object):
"""
if argument is None:
argument = swig_paddle.Arguments.createArguments(0)
assert isinstance(argument,swig_paddle.Arguments)
assert isinstance(argument, swig_paddle.Arguments)
argument.resize(len(self.__header__))
values = map(lambda x:
......@@ -394,10 +399,12 @@ class DataProviderWrapperConverter(object):
seq_dim[slot_idx].append(len(sequence))
for slot_idx in xrange(len(self.__header__)):
argument.setSlotSequenceDim(slot_idx, swig_paddle.IVector.create(
argument.setSlotSequenceDim(slot_idx,
swig_paddle.IVector.create(
seq_dim[slot_idx]))
argument.setSlotSequenceStartPositions(
slot_idx, swig_paddle.IVector.create(seq_start_pos[slot_idx]))
slot_idx,
swig_paddle.IVector.create(seq_start_pos[slot_idx]))
else:
for each_sample in wrapper_data:
for raw_data, value in zip(each_sample, values):
......@@ -415,6 +422,7 @@ class DataProviderWrapperConverter(object):
return self.convert(wrapper_data, argument)
def __monkey_patch_protobuf_objects__():
def ParameterConfig_toProto(self):
"""
......@@ -451,7 +459,8 @@ def __monkey_patch_protobuf_objects__():
:return: paddle.OptimizationConfig
"""
assert isinstance(protoObj, paddle.proto.TrainerConfig_pb2.OptimizationConfig)
assert isinstance(protoObj,
paddle.proto.TrainerConfig_pb2.OptimizationConfig)
return swig_paddle.OptimizationConfig.createFromProtoString(
protoObj.SerializeToString())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册