Commit 09993479 authored by guosheng

Fix python wrapper for layer_norm

Parent d63b7c60
@@ -116,8 +116,6 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
     // check input
     PADDLE_ENFORCE(ctx->HasInput("X"),
                    "Input(X) of LayerNormOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Scale"),
-                   "Input(Scale) of LayerNormOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Mean"),
                    "Input(Mean) of LayerNormOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Variance"),
@@ -1637,7 +1637,7 @@ def layer_norm(input,
             dtype=dtype,
             default_initializer=Constant(1.0))
         inputs['Scale'] = scale
-    if center:
+    if shift:
         assert bias_attr is not False
         bias = helper.create_parameter(
             attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True)
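For context, a minimal usage sketch of the fixed wrapper, assuming the contemporaneous paddle.fluid API (the variable names and shapes below are illustrative, not taken from this commit). With this change the Bias parameter is created only when the layer_norm wrapper's shift argument is true, matching the function's keyword arguments instead of the old center name.

    import paddle.fluid as fluid

    # Apply layer normalization with both the learnable gain (Scale)
    # and offset (Bias) enabled.
    x = fluid.layers.data(name='x', shape=[32, 32], dtype='float32')
    y = fluid.layers.layer_norm(
        input=x,
        scale=True,   # creates the Scale parameter
        shift=True,   # creates the Bias parameter (gated by `shift` after this fix)
        begin_norm_axis=1)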