未验证 提交 5c4f6464 编写于 作者: J Jason 提交者: GitHub

Merge pull request #806 from wjj19950828/rm_fluid

Clean fluid
......@@ -34,7 +34,7 @@ class DetectionOutput(object):
pbv = priorbox_list[1]
pb = paddle.reshape(x=pb, shape=[-1, 4])
pbv = paddle.reshape(x=pbv, shape=[-1, 4])
pb_dim = fluid.layers.shape(pb)[0]
pb_dim = paddle.shape(pb)[0]
loc = paddle.reshape(x0, shape=[-1, pb_dim, 4])
conf_flatten = paddle.reshape(x1, shape=[0, pb_dim, -1])
out = fluid.layers.detection_output(
......
......@@ -13,7 +13,6 @@
# limitations under the License.
import paddle
import paddle.fluid as fluid
class Normalize(object):
......@@ -21,7 +20,7 @@ class Normalize(object):
self.axis = axis
def __call__(self, x, param):
l2_norm = fluid.layers.l2_normalize(x=x, axis=1)
l2_norm = paddle.norm(x=x, p=2, axis=1, keepdim=True)
param = paddle.reshape(param, [param.shape[-1]])
perm = list(range(len(l2_norm.shape)))
perm.pop(self.axis)
......
......@@ -13,7 +13,87 @@
# limitations under the License.
import paddle
import paddle.fluid as fluid
from paddle import _C_ops
from paddle import in_dynamic_mode
from paddle.common_ops_import import Variable, LayerHelper, check_variable_and_dtype, check_type, check_dtype
@paddle.jit.not_to_static
def prior_box(input,
image,
min_sizes,
max_sizes=None,
aspect_ratios=[1.],
variance=[0.1, 0.1, 0.2, 0.2],
flip=False,
clip=False,
steps=[0.0, 0.0],
offset=0.5,
min_max_aspect_ratios_order=False,
name=None):
helper = LayerHelper("prior_box", **locals())
dtype = helper.input_dtype()
check_variable_and_dtype(
input, 'input', ['uint8', 'int8', 'float32', 'float64'], 'prior_box')
def _is_list_or_tuple_(data):
return (isinstance(data, list) or isinstance(data, tuple))
if not _is_list_or_tuple_(min_sizes):
min_sizes = [min_sizes]
if not _is_list_or_tuple_(aspect_ratios):
aspect_ratios = [aspect_ratios]
if not (_is_list_or_tuple_(steps) and len(steps) == 2):
raise ValueError('steps should be a list or tuple ',
'with length 2, (step_width, step_height).')
min_sizes = list(map(float, min_sizes))
aspect_ratios = list(map(float, aspect_ratios))
steps = list(map(float, steps))
cur_max_sizes = None
if max_sizes is not None and len(max_sizes) > 0 and max_sizes[0] > 0:
if not _is_list_or_tuple_(max_sizes):
max_sizes = [max_sizes]
cur_max_sizes = max_sizes
if in_dynamic_mode():
attrs = ('min_sizes', min_sizes, 'aspect_ratios', aspect_ratios,
'variances', variance, 'flip', flip, 'clip', clip, 'step_w',
steps[0], 'step_h', steps[1], 'offset', offset,
'min_max_aspect_ratios_order', min_max_aspect_ratios_order)
if cur_max_sizes is not None:
attrs += ('max_sizes', cur_max_sizes)
box, var = _C_ops.prior_box(input, image, *attrs)
return box, var
else:
attrs = {
'min_sizes': min_sizes,
'aspect_ratios': aspect_ratios,
'variances': variance,
'flip': flip,
'clip': clip,
'step_w': steps[0],
'step_h': steps[1],
'offset': offset,
'min_max_aspect_ratios_order': min_max_aspect_ratios_order
}
if cur_max_sizes is not None:
attrs['max_sizes'] = cur_max_sizes
box = helper.create_variable_for_type_inference(dtype)
var = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type="prior_box",
inputs={"Input": input,
"Image": image},
outputs={"Boxes": box,
"Variances": var},
attrs=attrs, )
box.stop_gradient = True
var.stop_gradient = True
return box, var
class PriorBox(object):
......@@ -32,8 +112,7 @@ class PriorBox(object):
}
def __call__(self, x0, x1):
box, var = fluid.layers.prior_box(
input=x0, image=x1, **self.priorbox_layer_attrs)
box, var = prior_box(input=x0, image=x1, **self.priorbox_layer_attrs)
box = paddle.reshape(x=box, shape=[1, 1, -1])
var = paddle.reshape(x=var, shape=[1, 1, -1])
out = paddle.concat(x=[box, var], axis=1)
......
......@@ -13,7 +13,51 @@
# limitations under the License.
import paddle
import paddle.fluid as fluid
from paddle import _C_ops
from paddle import in_dynamic_mode
from paddle.common_ops_import import Variable, LayerHelper, check_variable_and_dtype, check_type, check_dtype
@paddle.jit.not_to_static
def roi_pool(input,
rois,
pooled_height,
pooled_width,
spatial_scale=1.0,
rois_num=None,
name=None):
if in_dynamic_mode():
assert rois_num is not None, "rois_num should not be None in dygraph mode."
pool_out, argmaxes = _C_ops.roi_pool(
input, rois, rois_num, "pooled_height", pooled_height,
"pooled_width", pooled_width, "spatial_scale", spatial_scale)
return pool_out, argmaxes
else:
check_variable_and_dtype(input, 'input', ['float32'], 'roi_pool')
check_variable_and_dtype(rois, 'rois', ['float32'], 'roi_pool')
helper = LayerHelper('roi_pool', **locals())
dtype = helper.input_dtype()
pool_out = helper.create_variable_for_type_inference(dtype)
argmaxes = helper.create_variable_for_type_inference(dtype='int32')
inputs = {
"X": input,
"ROIs": rois,
}
if rois_num is not None:
inputs['RoisNum'] = rois_num
helper.append_op(
type="roi_pool",
inputs=inputs,
outputs={"Out": pool_out,
"Argmax": argmaxes},
attrs={
"pooled_height": pooled_height,
"pooled_width": pooled_width,
"spatial_scale": spatial_scale
})
return pool_out, argmaxes
class ROIPooling(object):
......@@ -26,6 +70,5 @@ class ROIPooling(object):
def __call__(self, x0, x1):
slice_x1 = paddle.slice(input=x1, axes=[1], starts=[1], ends=[5])
out = fluid.layers.roi_pool(
input=x0, rois=slice_x1, **self.roipooling_layer_attrs)
out = roi_pool(input=x0, rois=slice_x1, **self.roipooling_layer_attrs)
return out
......@@ -13,7 +13,6 @@
# limitations under the License.
import paddle
import paddle.fluid as fluid
class Select(object):
......
......@@ -429,13 +429,13 @@ class CaffeOpMapper():
assert params.local_size % 2 == 1
alpha = params.alpha / float(params.local_size)
layer_attrs = {
"n": params.local_size,
"k": params.k,
"size": params.local_size,
"alpha": alpha,
"beta": params.beta,
"k": params.k,
}
self.paddle_graph.add_layer(
"paddle.fluid.layers.lrn",
"paddle.nn.LocalResponseNorm",
inputs={"input": input.name},
outputs=[node.layer_name],
**layer_attrs)
......@@ -1209,10 +1209,10 @@ class CaffeOpMapper():
input = self.graph.get_input_node(node, idx=0, copy=True)
params = node.layer.shuffle_channel_param
self.paddle_graph.add_layer(
"paddle.fluid.layers.shuffle_channel",
"paddle.nn.functional.channel_shuffle",
inputs={"x": input.name},
outputs=[node.layer_name],
group=params.group)
groups=params.group)
def Upsample(self, node):
assert len(
......
......@@ -18,3 +18,5 @@ from .pad_all_dim2 import PadAllDim2
from .pad_all_dim4 import PadAllDim4
from .pad_all_dim4_one_input import PadAllDim4WithOneInput
from .nms import NMS
from .roi_align import ROIAlign
from .roi_pooling import ROIPooling
......@@ -13,9 +13,9 @@
# limitations under the License.
import paddle
from paddle.fluid import core
from paddle.fluid.framework import Variable, in_dygraph_mode
from paddle.fluid.layer_helper import LayerHelper
from paddle import _C_ops
from paddle import in_dynamic_mode
from paddle.common_ops_import import Variable, LayerHelper
def multiclass_nms(bboxes,
......@@ -33,13 +33,13 @@ def multiclass_nms(bboxes,
name=None):
helper = LayerHelper('multiclass_nms3', **locals())
if in_dygraph_mode():
if in_dynamic_mode():
attrs = ('background_label', background_label, 'score_threshold',
score_threshold, 'nms_top_k', nms_top_k, 'nms_threshold',
nms_threshold, 'keep_top_k', keep_top_k, 'nms_eta', nms_eta,
'normalized', normalized)
output, index, nms_rois_num = core.ops.multiclass_nms3(bboxes, scores,
rois_num, *attrs)
output, index, nms_rois_num = _C_ops.multiclass_nms3(bboxes, scores,
rois_num, *attrs)
if not return_index:
index = None
return output, nms_rois_num, index
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import _C_ops
from paddle import in_dynamic_mode
from paddle.common_ops_import import Variable, LayerHelper, check_variable_and_dtype, check_type, check_dtype
@paddle.jit.not_to_static
def roi_align(input,
rois,
pooled_height,
pooled_width,
spatial_scale=1.0,
sampling_ratio=-1,
rois_num=None,
aligned=False,
name=None):
if in_dynamic_mode():
assert rois_num is not None, "rois_num should not be None in dygraph mode."
align_out = _C_ops.roi_align(
input, rois, rois_num, "pooled_height", pooled_height,
"pooled_width", pooled_width, "spatial_scale", spatial_scale,
"sampling_ratio", sampling_ratio, "aligned", aligned)
return align_out
else:
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'roi_align')
check_variable_and_dtype(rois, 'rois', ['float32', 'float64'],
'roi_align')
helper = LayerHelper('roi_align', **locals())
dtype = helper.input_dtype()
align_out = helper.create_variable_for_type_inference(dtype)
inputs = {
"X": input,
"ROIs": rois,
}
if rois_num is not None:
inputs['RoisNum'] = rois_num
helper.append_op(
type="roi_align",
inputs=inputs,
outputs={"Out": align_out},
attrs={
"pooled_height": pooled_height,
"pooled_width": pooled_width,
"spatial_scale": spatial_scale,
"sampling_ratio": sampling_ratio,
"aligned": aligned,
})
return align_out
class ROIAlign(object):
def __init__(self, pooled_height, pooled_width, spatial_scale,
sampling_ratio):
self.roialign_layer_attrs = {
"pooled_height": pooled_height,
"pooled_width": pooled_width,
"spatial_scale": spatial_scale,
'sampling_ratio': sampling_ratio,
}
def __call__(self, x0, x1, x2):
out = roi_align(
input=x0, rois=x1, rois_num=x2, **self.roialign_layer_attrs)
return out
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import _C_ops
from paddle import in_dynamic_mode
from paddle.common_ops_import import Variable, LayerHelper, check_variable_and_dtype, check_type, check_dtype
@paddle.jit.not_to_static
def roi_pool(input,
rois,
pooled_height,
pooled_width,
spatial_scale=1.0,
rois_num=None,
name=None):
if in_dynamic_mode():
assert rois_num is not None, "rois_num should not be None in dygraph mode."
pool_out, argmaxes = _C_ops.roi_pool(
input, rois, rois_num, "pooled_height", pooled_height,
"pooled_width", pooled_width, "spatial_scale", spatial_scale)
return pool_out, argmaxes
else:
check_variable_and_dtype(input, 'input', ['float32'], 'roi_pool')
check_variable_and_dtype(rois, 'rois', ['float32'], 'roi_pool')
helper = LayerHelper('roi_pool', **locals())
dtype = helper.input_dtype()
pool_out = helper.create_variable_for_type_inference(dtype)
argmaxes = helper.create_variable_for_type_inference(dtype='int32')
inputs = {
"X": input,
"ROIs": rois,
}
if rois_num is not None:
inputs['RoisNum'] = rois_num
helper.append_op(
type="roi_pool",
inputs=inputs,
outputs={"Out": pool_out,
"Argmax": argmaxes},
attrs={
"pooled_height": pooled_height,
"pooled_width": pooled_width,
"spatial_scale": spatial_scale
})
return pool_out, argmaxes
class ROIPooling(object):
def __init__(self, pooled_height, pooled_width, spatial_scale):
self.roipooling_layer_attrs = {
"pooled_height": pooled_height,
"pooled_width": pooled_width,
"spatial_scale": spatial_scale
}
def __call__(self, x0, x1):
out = roi_pool(input=x0, rois=x1, **self.roipooling_layer_attrs)
return out
......@@ -538,12 +538,14 @@ class OpSet9():
'pooled_width': pooled_width,
'spatial_scale': spatial_scale,
'sampling_ratio': sampling_ratio,
'rois_num': val_rois_num,
}
self.paddle_graph.add_layer(
'paddle.fluid.layers.roi_align',
inputs={'input': val_x.name,
'rois': val_rois.name},
'custom_layer:ROIAlign',
inputs={
'input': val_x.name,
'rois': val_rois.name,
'rois_num': val_rois_num
},
outputs=[node.name],
**layer_attrs)
......@@ -560,7 +562,7 @@ class OpSet9():
'spatial_scale': spatial_scale,
}
self.paddle_graph.add_layer(
'paddle.fluid.layers.roi_pool',
'custom_layer:ROIPooling',
inputs={'input': val_x.name,
'rois': val_rois.name},
outputs=[node.name],
......
......@@ -612,7 +612,7 @@ def prim_shape_dim(layer,
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = fluid.layers.shape({})[{}]".format(
line = "{} = paddle.shape({})[{}]".format(
layer.outputs[0],
get_value(layer, "input", different_attrs),
get_value(layer, "dim", different_attrs))
......
......@@ -6025,7 +6025,7 @@ def aten_upsample_bilinear2d(mapper, graph, node):
inputs={"input": inputs_name[1]},
outputs=[inputs_name[1] + "_isinstance"],
scope_name=scope_name,
cls="paddle.fluid.Variable")
cls="paddle.static.Variable")
# TODO(syf): paddle.Variable
graph.add_layer(
"prim.if", {"input": inputs_name[1] + "_isinstance"},
......@@ -6103,7 +6103,7 @@ def aten_upsample_nearest2d(mapper, graph, node):
inputs={"input": inputs_name[1]},
outputs=[inputs_name[1] + "_isinstance"],
scope_name=scope_name,
cls="paddle.fluid.Variable")
cls="paddle.static.Variable")
# TODO(syf): paddle.Variable
graph.add_layer(
"prim.if", {"input": inputs_name[1] + "_isinstance"},
......
......@@ -14,7 +14,7 @@
import paddle
from paddle.nn.functional import instance_norm
from paddle.fluid.initializer import Constant
from paddle.nn.initializer import Constant
class InstanceNorm(paddle.nn.Layer):
......
......@@ -46,7 +46,7 @@ class InterpolateBilinearFuser(FuseBase):
if x2271 :
x2274 = x2197[0]
x2275 = x2197[1]
x2233_isinstance = isinstance(x2233, paddle.fluid.Variable)
x2233_isinstance = isinstance(x2233, paddle.static.Variable)
if x2233_isinstance :
x2233 = x2233.numpy().tolist()
x2276 = paddle.nn.functional.interpolate(x=x2181, size=x2233, scale_factor=x2274, align_corners=False, align_mode=0, mode='bilinear')
......@@ -146,7 +146,7 @@ class InterpolateBilinearFuser(FuseBase):
"prim.isinstance",
inputs={"input": "interpolate-input-3"},
outputs=["interpolate-input-0_isinstance"],
cls="paddle.fluid.Variable")
cls="paddle.static.Variable")
pattern_block_block.add_layer(
"prim.if", {"input": "interpolate-input-0_isinstance"},
outputs=["interpolate-input-0_if1"])
......
......@@ -103,15 +103,7 @@ class PaddleDtypes():
self.t_int64 = paddle.int64
self.t_bool = paddle.bool
else:
self.t_float16 = "paddle.fluid.core.VarDesc.VarType.FP16"
self.t_float32 = "paddle.fluid.core.VarDesc.VarType.FP32"
self.t_float64 = "paddle.fluid.core.VarDesc.VarType.FP64"
self.t_uint8 = "paddle.fluid.core.VarDesc.VarType.UINT8"
self.t_int8 = "paddle.fluid.core.VarDesc.VarType.INT8"
self.t_int16 = "paddle.fluid.core.VarDesc.VarType.INT16"
self.t_int32 = "paddle.fluid.core.VarDesc.VarType.INT32"
self.t_int64 = "paddle.fluid.core.VarDesc.VarType.INT64"
self.t_bool = "paddle.fluid.core.VarDesc.VarType.BOOL"
raise Exception("Paddle>=2.0.0 is required, Please update version!")
is_new_version = check_version()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册