提交 f87b9db5 编写于 作者: H hjchen2

Fix bugs

上级 87ac1969
......@@ -28,11 +28,11 @@ class AnchorGeneratorParam : public OpParam {
AnchorGeneratorParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_ = OpParam::GetVarValue<framework::Tensor>("Input", inputs, scope);
input_ = OpParam::GetVarValue<framework::LoDTensor>("Input", inputs, scope);
output_anchors_ =
OpParam::GetVarValue<framework::Tensor>("Anchors", outputs, scope);
OpParam::GetVarValue<framework::LoDTensor>("Anchors", outputs, scope);
output_variances_ =
OpParam::GetVarValue<framework::Tensor>("Variances", outputs, scope);
OpParam::GetVarValue<framework::LoDTensor>("Variances", outputs, scope);
anchor_sizes_ = OpParam::GetAttr<std::vector<float>>("anchor_sizes", attrs);
aspect_ratios_ =
......@@ -65,14 +65,16 @@ class ProposalParam : public OpParam {
public:
ProposalParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
scores_ = OpParam::GetVarValue<framework::Tensor>("Scores", inputs, scope);
scores_ =
OpParam::GetVarValue<framework::LoDTensor>("Scores", inputs, scope);
bbox_deltas_ =
OpParam::GetVarValue<framework::Tensor>("BboxDeltas", inputs, scope);
im_info_ = OpParam::GetVarValue<framework::Tensor>("ImInfo", inputs, scope);
OpParam::GetVarValue<framework::LoDTensor>("BboxDeltas", inputs, scope);
im_info_ =
OpParam::GetVarValue<framework::LoDTensor>("ImInfo", inputs, scope);
anchors_ =
OpParam::GetVarValue<framework::Tensor>("Anchors", inputs, scope);
OpParam::GetVarValue<framework::LoDTensor>("Anchors", inputs, scope);
variances_ =
OpParam::GetVarValue<framework::Tensor>("Variances", inputs, scope);
OpParam::GetVarValue<framework::LoDTensor>("Variances", inputs, scope);
rpn_rois_ =
OpParam::GetVarValue<framework::LoDTensor>("RpnRois", outputs, scope);
......@@ -112,10 +114,10 @@ class PSRoiPoolParam : public OpParam {
public:
PSRoiPoolParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_x_ = OpParam::GetVarValue<framework::Tensor>("X", inputs, scope);
input_x_ = OpParam::GetVarValue<framework::LoDTensor>("X", inputs, scope);
input_rois_ =
OpParam::GetVarValue<framework::LoDTensor>("ROIs", inputs, scope);
output_ = OpParam::GetVarValue<framework::Tensor>("Out", outputs, scope);
output_ = OpParam::GetVarValue<framework::LoDTensor>("Out", outputs, scope);
output_channels_ = OpParam::GetAttr<int>("output_channels", attrs);
pooled_height_ = OpParam::GetAttr<int>("pooled_height", attrs);
......@@ -143,10 +145,10 @@ class RoiPerspectiveParam : public OpParam {
RoiPerspectiveParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) {
input_x_ = OpParam::GetVarValue<framework::Tensor>("X", inputs, scope);
input_x_ = OpParam::GetVarValue<framework::LoDTensor>("X", inputs, scope);
input_rois_ =
OpParam::GetVarValue<framework::LoDTensor>("ROIs", inputs, scope);
output_ = OpParam::GetVarValue<framework::Tensor>("Out", outputs, scope);
output_ = OpParam::GetVarValue<framework::LoDTensor>("Out", outputs, scope);
spatial_scale_ = OpParam::GetAttr<float>("spatial_scale", attrs);
transformed_height_ = OpParam::GetAttr<int>("transformed_height", attrs);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册