提交 09993479 编写于 作者: G guosheng

Fix python wrapper for layer_norm

上级 d63b7c60
...@@ -116,8 +116,6 @@ class LayerNormGradOp : public framework::OperatorWithKernel { ...@@ -116,8 +116,6 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
// check input // check input
PADDLE_ENFORCE(ctx->HasInput("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of LayerNormOp should not be null."); "Input(X) of LayerNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Scale"),
"Input(Scale) of LayerNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Mean"), PADDLE_ENFORCE(ctx->HasInput("Mean"),
"Input(Mean) of LayerNormOp should not be null."); "Input(Mean) of LayerNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Variance"), PADDLE_ENFORCE(ctx->HasInput("Variance"),
......
...@@ -1637,7 +1637,7 @@ def layer_norm(input, ...@@ -1637,7 +1637,7 @@ def layer_norm(input,
dtype=dtype, dtype=dtype,
default_initializer=Constant(1.0)) default_initializer=Constant(1.0))
inputs['Scale'] = scale inputs['Scale'] = scale
if center: if shift:
assert bias_attr is not False assert bias_attr is not False
bias = helper.create_parameter( bias = helper.create_parameter(
attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True) attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册