提交 53781fc0 编写于 作者: Q Qiao Longfei

fix some bug

上级 3f91e0f0
...@@ -8099,12 +8099,12 @@ def bilinear_tensor_product(x, ...@@ -8099,12 +8099,12 @@ def bilinear_tensor_product(x,
position_tensor = fluid.layers.add_position_encoding(input=tensor) position_tensor = fluid.layers.add_position_encoding(input=tensor)
""" """
helper = LayerHelper('bilinear_tensor_product', **locals()) helper = LayerHelper('bilinear_tensor_product', **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype('x')
param_shape = [size, x.shape[1], y.shape[1]] param_shape = [size, x.shape[1], y.shape[1]]
w = helper.create_parameter( w = helper.create_parameter(
attr=param_attr, shape=param_shape, dtype=dtype, is_bias=False) attr=helper.param_attr, shape=param_shape, dtype=dtype, is_bias=False)
if name is None: if name is None:
out = helper.create_variable_for_type_inference(dtype=dtype) out = helper.create_variable_for_type_inference(dtype=dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册