Commit 61444d90 authored by qijun

Merge remote-tracking branch 'baidu/develop' into feature/sppnet

@@ -187,6 +187,15 @@ MatrixPtr Matrix::subMatrix(size_t startRow, size_t endRow, size_t startCol,
                        trans_, useGpu_);
 }
 
+void Matrix::setDiag(real value) {
+  CHECK(data_ != NULL);
+  CHECK_EQ(height_, width_);
+
+  zeroMem();
+  BaseMatrix diag(height_, 1, stride_ + 1, data_, false, useGpu_);
+  diag.assign(value);
+}
+
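A note on the stride trick in `Matrix::setDiag`: constructing a `height_ x 1` `BaseMatrix` view over the same buffer with row stride `stride_ + 1` makes its single column land exactly on the main diagonal, so `assign(value)` fills the diagonal without an explicit loop. A minimal sketch of the same idea in numpy (numpy is used only for illustration and assumes a contiguous buffer, i.e. `stride_ == width_`; it is not part of the patch):

```python
import numpy as np

n = 4
a = np.zeros((n, n), dtype=np.float32)

# Stepping a flat view by n + 1 visits exactly the main diagonal:
# flat offsets 0, n+1, 2*(n+1), ... are a[0][0], a[1][1], a[2][2], ...
a.reshape(-1)[::n + 1] = 1.0

assert np.array_equal(a, np.eye(n, dtype=np.float32))
```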
 GpuMatrix::GpuMatrix(size_t height, size_t width, bool trans)
     : Matrix(std::make_shared<GpuMemoryHandle>(height * width * sizeof(real)),
              height, width, trans, true) {}
@@ -202,6 +211,7 @@ void GpuMatrix::resetOne() {
   CHECK(data_ != NULL);
   one();
 }
+
 void GpuMatrix::resize(size_t newHeight, size_t newWidth) {
   size_t newSize = newHeight * newWidth;
   if (NULL == memoryHandle_.get() ||
......
@@ -195,6 +195,8 @@ public:
   virtual void resetOne() { LOG(FATAL) << "Not implemented"; }
 
+  void setDiag(real value);
+
   virtual void copyFrom(const Matrix& src) { LOG(FATAL) << "Not implemented"; }
 
   virtual void trimFrom(const CpuSparseMatrix& src) {
@@ -330,6 +332,7 @@ public:
   virtual MatrixPtr getInverse() {
     LOG(FATAL) << "Not implemented";
+    return nullptr;
   }
 
   /**
@@ -1016,6 +1019,7 @@ public:
   void zeroMem();
   void resetOne();
+  void setDiag(real value);
   void resize(size_t newHeight, size_t newWidth);
   void resize(size_t newHeight, size_t newWidth,
@@ -1280,6 +1284,8 @@ public:
   void zeroMem();
   void resetOne();
+  void setDiag(real value);
   void resize(size_t newHeight, size_t newWidth);
   void resize(size_t newHeight, size_t newWidth,
               size_t newNnz, /* used to allocate space */
......
@@ -647,20 +647,23 @@ void testMatrixInverse(int height) {
   MatrixPtr cpuI = std::make_shared<CpuMatrix>(height, height);
   MatrixPtr gpuI = std::make_shared<GpuMatrix>(height, height);
 
+  /* Make matrix well conditioned: cpu * cpuT + Identity */
   cpu->randomizeUniform();
+  MatrixPtr cpuT = cpu->getTranspose();
+  MatrixPtr outputCheck = std::make_shared<CpuMatrix>(height, height);
+  outputCheck->mul(cpu, cpuT);
+  cpu->setDiag(1.0);
+  cpu->add(*outputCheck);
+
   gpu->copyFrom(*cpu);
   cpu->inverse(cpuI, false);
   gpu->inverse(gpuI, false);
 
-  MatrixPtr outputCheck = std::make_shared<CpuMatrix>(height, height);
   outputCheck->copyFrom(*gpuI);
   MatrixCheckErr(*cpuI, *outputCheck);
 
   outputCheck->mul(cpu, cpuI);
-  cpu->zeroMem();
-  for (int i = 0; i < height; i++) {
-    cpu->getRowBuf(i)[i] = 1.0;
-  }
+  cpu->setDiag(1.0);
   MatrixCheckErr(*cpu, *outputCheck);
 }
......
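Why the test now builds its input as `cpu * cpuT + Identity` before inverting: for any real matrix $A$, the matrix $A A^\top + I$ is symmetric positive definite, since for every nonzero $x$

$$x^\top (A A^\top + I)\, x = \lVert A^\top x \rVert^2 + \lVert x \rVert^2 > 0.$$

Every eigenvalue is therefore at least 1, so the matrix is guaranteed invertible and reasonably well conditioned, which keeps the CPU/GPU inverse comparison numerically stable; a raw `randomizeUniform()` matrix can be near-singular and make the element-wise check flaky.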
@@ -592,7 +592,7 @@ class MixedLayerType(LayerOutput):
     def __exit__(self, *args, **kwargs):
         del args, kwargs  # unused parameter to suppress warning
         assert len(self.inputs) != 0
-        MixedLayer(
+        ml = MixedLayer(
             name=self.name,
             size=self.size,
             active_type=self.activation.name,
@@ -600,6 +600,9 @@ class MixedLayerType(LayerOutput):
             inputs=self.inputs,
             **ExtraLayerAttribute.to_kwargs(self.layer_attr)
         )
+        # update the size which might be computed inside MixedLayer
+        # according to the operator's output size
+        self.size = ml.config.size
 
 
 @wrap_name_default("mixed")
@@ -2104,7 +2107,7 @@ def concat_layer(input, act=None, name=None, layer_attr=None, bias_attr=None):
     if layer_type == LayerType.CONCAT_LAYER:
         assert not bias_attr
 
     Layer(
         name=name, type=layer_type,
         inputs=[x.name for x in input] if is_concat_layer else input,
@@ -2682,7 +2685,7 @@ def out_prod_layer(input1, input2, name=None, layer_attr=None):
     assert isinstance(input1, LayerOutput)
     assert isinstance(input2, LayerOutput)
     Layer(name=name,
-          type="out_prod",
+          type=LayerType.OUT_PROD_LAYER,
           inputs=[input1.name, input2.name],
           **ExtraLayerAttribute.to_kwargs(layer_attr))
     return LayerOutput(name=name,
@@ -2849,7 +2852,7 @@ def beam_search(step, input, bos_id, eos_id, beam_size,
 def __cost_input__(input, label, weight=None):
     """
     inputs and parents for cost layers.
     """
     ipts = [Input(input.name), Input(label.name)]
     parents = [input, label]
@@ -2858,7 +2861,7 @@ def __cost_input__(input, label, weight=None):
         ipts.append(Input(weight.name))
         parents.append(weight)
     return ipts, parents
 
 
 @wrap_name_default()
 @layer_support()
@@ -2943,7 +2946,7 @@ def classification_cost(input, label, weight=None, name=None,
 def conv_operator(img, filter, filter_size, num_filters,
-                  num_channel=None, stride=1, padding=0,
+                  num_channels=None, stride=1, padding=0,
                   filter_size_y=None, stride_y=None, padding_y=None):
     """
     Different from img_conv_layer, conv_op is an Operator, which can be used
@@ -2973,8 +2976,8 @@ def conv_operator(img, filter, filter_size, num_filters,
     :type filter_size_y: int
     :param num_filters: channel of output data.
     :type num_filters: int
-    :param num_channel: channel of input data.
-    :type num_channel: int
+    :param num_channels: channel of input data.
+    :type num_channels: int
     :param stride: The x dimension of the stride.
     :type stride: int
     :param stride_y: The y dimension of the stride.
@@ -2993,19 +2996,19 @@ def conv_operator(img, filter, filter_size, num_filters,
     if padding_y is None:
         padding_y = padding
-    if num_channel is None:
-        num_channel = img.num_filters
+    if num_channels is None:
+        num_channels = img.num_filters
     assert isinstance(filter, LayerOutput)
     if filter.size is not None:
-        filter.size = filter_size * filter_size_y * num_filters * num_channel
+        filter.size = filter_size * filter_size_y * num_filters * num_channels
     op = ConvOperator(input_layer_names=[img.name, filter.name],
                       num_filters=num_filters,
                       conv_conf=Conv(filter_size=filter_size,
                                      padding=padding,
                                      stride=stride,
-                                     channels=num_channel,
+                                     channels=num_channels,
                                      filter_size_y=filter_size_y,
                                      padding_y=padding_y,
                                      stride_y=stride_y,
@@ -3045,8 +3048,8 @@ def conv_projection(input, filter_size, num_filters,
     :type filter_size_y: int
     :param num_filters: channel of output data.
     :type num_filters: int
-    :param num_channel: channel of input data.
-    :type num_channel: int
+    :param num_channels: channel of input data.
+    :type num_channels: int
     :param stride: The x dimension of the stride.
     :type stride: int
     :param stride_y: The y dimension of the stride.
@@ -3537,15 +3540,15 @@ def maxout_layer(input,
     - Input: output of a conv layer.
     - Output: feature map size same as input. Channel is (input channel) / groups.
     So groups should be larger than 1, and the num of channels should be able
     to devided by groups.
 
     Please refer to Paper:
     - Maxout Networks: http://www.jmlr.org/proceedings/papers/v28/goodfellow13.pdf
     - Multi-digit Number Recognition from Street View \
       Imagery using Deep Convolutional Neural Networks: \
       https://arxiv.org/pdf/1312.6082v4.pdf
 
     The simple usage is:
 
     .. code-block:: python
@@ -3790,9 +3793,9 @@ def nce_layer(input, label, num_classes, weight=None,
     :param weight: weight layer, can be None(default)
     :type weight: LayerOutput
     :param num_classes: number of classes.
     :type num_classes: int
     :param num_neg_samples: number of negative samples. Default is 10.
     :type num_neg_samples: int
     :param neg_distribution: The distribution for generating the random negative labels.
                              A uniform distribution will be used if not provided.
                              If not None, its length must be equal to num_classes.
@@ -3813,7 +3816,7 @@ def nce_layer(input, label, num_classes, weight=None,
         assert isinstance(neg_distribution, collections.Sequence)
         assert len(neg_distribution) == num_classes
         assert sum(neg_distribution) == 1
 
     ipts_for_layer = []
     parents = []
     for each_input in input:
......
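Note that the `num_channel` → `num_channels` rename in `conv_operator` is a breaking change for existing call sites: the keyword must be updated wherever the operator is configured, as the two test configs below do. A sketch of the change at a call site (identifiers taken from the first test config):

```python
# before this commit:
#   m7 += conv_operator(img=img, filter=flt, num_filters=64,
#                       num_channel=1, filter_size=3)
# after:
m7 += conv_operator(img=img, filter=flt, num_filters=64,
                    num_channels=1, filter_size=3)
```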
@@ -35,7 +35,7 @@ flt = data_layer(name='filter', size=3*3*1*64)
 
 with mixed_layer() as m7:
     m7 += conv_operator(img=img, filter=flt, num_filters=64,
-                        num_channel=1, filter_size=3)
+                        num_channels=1, filter_size=3)
 
 end = mixed_layer(input=[full_matrix_projection(input=m5),
                          trans_full_matrix_projection(input=m6),
......
@@ -29,9 +29,11 @@ z1 = mixed_layer(act=LinearActivation(),
                                filter=y1,
                                filter_size=1,
                                num_filters=5,
-                               num_channel=5,
+                               num_channels=5,
                                stride=1)])
+assert z1.size > 0
+
 y2 = fc_layer(input=y, size=15)
 cos1 = cos_sim(a=x1, b=y1)
......
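The added `assert z1.size > 0` exercises the `MixedLayerType.__exit__` change above: the size computed inside `MixedLayer` from the conv operator's output shape is now copied back onto the Python-side layer output, where it previously went unset. A hedged sketch of the arithmetic behind that size, assuming the standard convolution output formula:

```python
def conv_output_dim(input_dim, filter_size, padding=0, stride=1):
    # standard convolution output-size arithmetic
    return (input_dim + 2 * padding - filter_size) // stride + 1

# In the config above the filters are 1x1 with stride 1, so the spatial
# size is unchanged and, for an x-by-y input with num_filters=5:
#   z1.size == 5 * conv_output_dim(x, 1) * conv_output_dim(y, 1) == 5 * x * y
```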