提交 55eb2fcf 编写于 作者: H Haonan

format correction

上级 f5995300
...@@ -40,15 +40,13 @@ void RotateLayer::forward(PassType passType) { ...@@ -40,15 +40,13 @@ void RotateLayer::forward(PassType passType) {
MatrixPtr outV = getOutputValue(); MatrixPtr outV = getOutputValue();
for (int b = 0; b < batchSize_; b ++) { for (int b = 0; b < batchSize_; b++) {
MatrixPtr inputSample MatrixPtr inputSample = Matrix::create(input->getData() + b * sampleSize_,
= Matrix::create(input->getData() + b * sampleSize_,
sampleHeight_, sampleHeight_,
sampleWidth_, sampleWidth_,
false, false,
useGpu_); useGpu_);
MatrixPtr outputSample MatrixPtr outputSample = Matrix::create(outV->getData() + b * sampleSize_,
= Matrix::create(outV->getData() + b * sampleSize_,
sampleWidth_, sampleWidth_,
sampleHeight_, sampleHeight_,
false, false,
...@@ -71,21 +69,21 @@ void RotateLayer::backward(const UpdateCallback& callback) { ...@@ -71,21 +69,21 @@ void RotateLayer::backward(const UpdateCallback& callback) {
// the grad should be rotated in the reverse direction // the grad should be rotated in the reverse direction
MatrixPtr preGrad = getInputGrad(0); MatrixPtr preGrad = getInputGrad(0);
for (int b = 0; b < batchSize_; b ++) { for (int b = 0; b < batchSize_; b++) {
MatrixPtr inputSampleGrad MatrixPtr inputSampleGrad =
= Matrix::create(preGrad->getData() + b * sampleSize_, Matrix::create(preGrad->getData() + b * sampleSize_,
sampleHeight_, sampleHeight_,
sampleWidth_, sampleWidth_,
false, false,
useGpu_); useGpu_);
MatrixPtr outputSampleGrad MatrixPtr outputSampleGrad =
= Matrix::create(outputGrad->getData() + b * sampleSize_, Matrix::create(outputGrad->getData() + b * sampleSize_,
sampleWidth_, sampleWidth_,
sampleHeight_, sampleHeight_,
false, false,
useGpu_); useGpu_);
MatrixPtr tmpGrad MatrixPtr tmpGrad =
= Matrix::create(sampleHeight_, sampleWidth_, false, useGpu_); Matrix::create(sampleHeight_, sampleWidth_, false, useGpu_);
outputSampleGrad->rotate(tmpGrad, false, false); outputSampleGrad->rotate(tmpGrad, false, false);
inputSampleGrad->add(*tmpGrad); inputSampleGrad->add(*tmpGrad);
} }
......
...@@ -1833,7 +1833,6 @@ class PoolLayer(LayerBase): ...@@ -1833,7 +1833,6 @@ class PoolLayer(LayerBase):
pool_conf.channels) pool_conf.channels)
@config_layer('spp') @config_layer('spp')
class SpatialPyramidPoolLayer(LayerBase): class SpatialPyramidPoolLayer(LayerBase):
def __init__(self, name, inputs, **xargs): def __init__(self, name, inputs, **xargs):
......
...@@ -1707,12 +1707,14 @@ def rotate_layer(input, height, name=None, layer_attr=None): ...@@ -1707,12 +1707,14 @@ def rotate_layer(input, height, name=None, layer_attr=None):
:rtype: LayerOutput :rtype: LayerOutput
""" """
assert isinstance(input, LayerOutput) assert isinstance(input, LayerOutput)
l = Layer(name=name, l = Layer(
name=name,
height=height, height=height,
type=LayerType.ROTATE_LAYER, type=LayerType.ROTATE_LAYER,
inputs=[input.name], inputs=[input.name],
**ExtraLayerAttribute.to_kwargs(layer_attr)) **ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput(name=name, return LayerOutput(
name=name,
layer_type=LayerType.ROTATE_LAYER, layer_type=LayerType.ROTATE_LAYER,
parents=[input], parents=[input],
size=l.config.size) size=l.config.size)
...@@ -1750,8 +1752,9 @@ def flip_layer(input, height, name=None, layer_attr=None): ...@@ -1750,8 +1752,9 @@ def flip_layer(input, height, name=None, layer_attr=None):
:rtype: LayerOutput :rtype: LayerOutput
""" """
assert isinstance(input, LayerOutput) assert isinstance(input, LayerOutput)
return rotate_layer(input=rotate_layer(input=input, return rotate_layer(
height=height), input=rotate_layer(
input=input, height=height),
height=height, height=height,
name=name, name=name,
layer_attr=layer_attr) layer_attr=layer_attr)
......
...@@ -39,10 +39,8 @@ z1 = mixed_layer( ...@@ -39,10 +39,8 @@ z1 = mixed_layer(
assert z1.size > 0 assert z1.size > 0
y2 = fc_layer(input=y, size=15) y2 = fc_layer(input=y, size=15)
z2 = rotate_layer(input=y2, z2 = rotate_layer(input=y2, height=5)
height=5) z3 = flip_layer(input=y2, height=3)
z3 = flip_layer(input=y2,
height=3)
cos1 = cos_sim(a=x1, b=y1) cos1 = cos_sim(a=x1, b=y1)
cos3 = cos_sim(a=x1, b=y2, size=3) cos3 = cos_sim(a=x1, b=y2, size=3)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册