未验证 提交 bdcc2ad4 编写于 作者: B baoachun 提交者: GitHub

fix interpolate mkldnn op error (#36662)

上级 5f1b193a
...@@ -104,10 +104,12 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> { ...@@ -104,10 +104,12 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
scale.push_back(scale[0]); scale.push_back(scale[0]);
} else { // v2 } else { // v2
std::vector<float> scale_attr = ctx.Attr<std::vector<float>>("scale"); std::vector<float> scale_attr = ctx.Attr<std::vector<float>>("scale");
if (scale_attr.size() > 0) {
scale.resize(3, scale_attr[0]); scale.resize(3, scale_attr[0]);
std::copy(scale_attr.begin(), scale_attr.end(), scale.begin()); std::copy(scale_attr.begin(), scale_attr.end(), scale.begin());
} }
} }
}
if (scale[0] > 0.0f && scale[1] > 0.0f && scale[2] > 0.0f) { if (scale[0] > 0.0f && scale[1] > 0.0f && scale[2] > 0.0f) {
int j = 0; int j = 0;
std::vector<int64_t> in_dhw_vec = framework::vectorize(in_dhw_dims); std::vector<int64_t> in_dhw_vec = framework::vectorize(in_dhw_dims);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册