未验证 提交 714b0076 编写于 作者: Z Zhang Ting 提交者: GitHub

Override GetKernelTypeForVar to avoid device transform, test=develop (#23032)

上级 112e3edb
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/shape_op.h"
#include <string>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
......@@ -30,6 +31,15 @@ class ShapeOp : public framework::OperatorWithKernel {
auto in_dim = ctx->GetInputDim("Input");
ctx->SetOutputDim("Out", {in_dim.size()});
}
protected:
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
return framework::OpKernelType(expected_kernel_type.data_type_,
expected_kernel_type.place_,
tensor.layout());
}
};
class ShapeOpMaker : public framework::OpProtoAndCheckerMaker {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册