未验证 提交 1b0cecb7 编写于 作者: C Chen Weihang 提交者: GitHub

polish several details (#40485)

上级 f3f27d25
......@@ -249,13 +249,13 @@ class CompatMetaTensor : public phi::MetaTensor {
}
void share_meta(const MetaTensor& meta_tensor) override {
share_dims(meta_tensor);
set_dtype(meta_tensor.dtype());
// VarDesc doesn't contains layout, so we cannot share layout
// set_layout(meta_tensor.layout());
// special case 1: share lod of LoDTensor
// special case: share lod of LoDTensor
share_lod(meta_tensor);
share_dims(meta_tensor);
}
private:
......
......@@ -215,7 +215,7 @@ REGISTER_OPERATOR(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker,
ops::SoftmaxOpGradMaker<paddle::framework::OpDesc>,
ops::SoftmaxOpGradMaker<paddle::imperative::OpBase>,
ops::SoftmaxInplaceInferer, SoftmaxInferShapeFunctor);
DECLARE_INFER_SHAPE_FUNCTOR(softmax_grad, SoftmaxGradnferShapeFunctor,
DECLARE_INFER_SHAPE_FUNCTOR(softmax_grad, SoftmaxGradInferShapeFunctor,
PD_INFER_META(phi::GeneralUnaryGradInferMeta));
REGISTER_OPERATOR(softmax_grad, ops::SoftmaxOpGrad,
SoftmaxGradnferShapeFunctor);
SoftmaxGradInferShapeFunctor);
......@@ -110,7 +110,7 @@ void MetaTensor::share_meta(const MetaTensor& meta_tensor) {
}
}
TensorBase* MetaTensor::get_tensor() const { return tensor_; }
TensorBase* MetaTensor::tensor() const { return tensor_; }
void MetaTensor::share_dims(const MetaTensor& meta_tensor) {
bool is_dense_tensor = phi::DenseTensor::classof(tensor_);
......@@ -118,7 +118,7 @@ void MetaTensor::share_dims(const MetaTensor& meta_tensor) {
if (is_dense_tensor || is_selected_rows) {
set_dims(meta_tensor.dims());
if (is_selected_rows) {
const auto in_tensor_base = meta_tensor.get_tensor();
const auto in_tensor_base = meta_tensor.tensor();
PADDLE_ENFORCE_EQ(
phi::SelectedRows::classof(in_tensor_base),
true,
......
......@@ -66,7 +66,7 @@ class MetaTensor {
// Because the lod in compiletime and runtime is different,
// so `LoD` cannot in public methods
const LoD& lod() const;
TensorBase* get_tensor() const;
TensorBase* tensor() const;
TensorBase* tensor_;
};
......
......@@ -39,7 +39,7 @@ void ComputeInverseEigen(const Context& dev_ctx,
int batch_size = rank > 2 ? a.numel() / (n * n) : 1;
const T* a_ptr = a.data<T>();
T* a_inv_ptr = a_inv->mutable_data<T>(dev_ctx.GetPlace());
T* a_inv_ptr = dev_ctx.template Alloc<T>(a_inv);
for (int i = 0; i < batch_size; ++i) {
ConstEigenMatrixMap mat(a_ptr + i * n * n, n, n);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册