未验证 提交 ab610a34 编写于 作者: A Adam 提交者: GitHub

transpose_mkldnn code change to meet Paddle standards (#22591)

上级 8f035fb6
......@@ -40,7 +40,8 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const T* input_data = input->data<T>();
if (ndims == 1) {
output->ShareDataWith(*input);
framework::TensorCopy(*input, input->place(), output);
output->set_format(input->format());
return;
}
......@@ -85,7 +86,8 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> reversed_axis(axis);
int ndims = axis.size();
if (ndims == 1) {
x_grad->ShareDataWith(*out_grad);
framework::TensorCopy(*out_grad, out_grad->place(), x_grad);
x_grad->set_format(out_grad->format());
return;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册