提交 6b9f1d34 编写于 作者: G guosheng

Make weight normalization adapt to the up-to-date code

上级 a422f346
...@@ -226,8 +226,8 @@ class LayerHelper(object): ...@@ -226,8 +226,8 @@ class LayerHelper(object):
scale = elementwise_div( scale = elementwise_div(
x=g, y=norm) # The shapes of g and norm are the same. x=g, y=norm) # The shapes of g and norm are the same.
# Currently, elementwise_mul only support broadcast when the shape # Currently, elementwise_mul only support broadcast when the shape
# of y is a subset of x. Thus, we should reshape y to squeeze to # of y is a subset of the shape of x. Thus, we reshape y to squeeze
# achive it. # to achive the subset.
w = elementwise_mul( w = elementwise_mul(
x=v, x=v,
y=scale if dim is None else reshape( y=scale if dim is None else reshape(
......
...@@ -15,7 +15,10 @@ ...@@ -15,7 +15,10 @@
from initializer import Initializer, Xavier, Constant from initializer import Initializer, Xavier, Constant
from regularizer import WeightDecayRegularizer from regularizer import WeightDecayRegularizer
__all__ = ['ParamAttr'] __all__ = [
'ParamAttr',
'WeightNormParamAttr',
]
class ParamAttr(object): class ParamAttr(object):
...@@ -92,7 +95,7 @@ class WeightNormParamAttr(ParamAttr): ...@@ -92,7 +95,7 @@ class WeightNormParamAttr(ParamAttr):
""" """
# List to record the parameters reparameterized by weight normalization. # List to record the parameters reparameterized by weight normalization.
# If these parameters are treated as Variable rather than Parameter, # If these parameters are treated as Variable rather than Parameter,
# it can be used to discriminate these parameters and help to serialize # it can be used to discriminate these parameters and help to serialize
# these paramters for inference. # these paramters for inference.
params_with_weight_norm = [] params_with_weight_norm = []
......
...@@ -52,7 +52,7 @@ class TestWeightNormalization(unittest.TestCase): ...@@ -52,7 +52,7 @@ class TestWeightNormalization(unittest.TestCase):
def run_program(self): def run_program(self):
outputs = [] outputs = []
places = [core.CPUPlace()] places = [core.CPUPlace()]
if core.is_compile_gpu(): if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0)) places.append(core.CUDAPlace(0))
for place in places: for place in places:
self.set_inputs(place) self.set_inputs(place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册