Unverified · Commit 1b0cecb7 · Authored by Chen Weihang · Committed by GitHub

polish several details (#40485)

Parent: f3f27d25
@@ -249,13 +249,13 @@ class CompatMetaTensor : public phi::MetaTensor {
   }

   void share_meta(const MetaTensor& meta_tensor) override {
+    share_dims(meta_tensor);
     set_dtype(meta_tensor.dtype());
     // VarDesc doesn't contains layout, so we cannot share layout
     // set_layout(meta_tensor.layout());

-    // special case 1: share lod of LoDTensor
+    // special case: share lod of LoDTensor
     share_lod(meta_tensor);
-    share_dims(meta_tensor);
   }

  private:
......
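The substantive change in the hunk above is ordering: share_dims(meta_tensor) now runs first in share_meta, ahead of the dtype and LoD sharing, instead of last. A minimal standalone sketch of the resulting call order, assuming a simplified MiniMetaTensor type (illustrative only, not Paddle's actual MetaTensor API):

// A minimal sketch of the polished share_meta call order; MiniMetaTensor
// and its fields are illustrative assumptions, not Paddle's classes.
#include <cstdint>
#include <vector>

struct MiniMetaTensor {
  std::vector<int64_t> dims;
  int dtype = 0;                         // stand-in for phi::DataType
  std::vector<std::vector<size_t>> lod;  // stand-in for LoD

  void share_dims(const MiniMetaTensor& other) { dims = other.dims; }
  void set_dtype(int d) { dtype = d; }
  void share_lod(const MiniMetaTensor& other) { lod = other.lod; }

  // Mirrors the new order: dims first, then dtype, then the LoD
  // special case (layout is intentionally not shared).
  void share_meta(const MiniMetaTensor& other) {
    share_dims(other);
    set_dtype(other.dtype);
    share_lod(other);
  }
};

int main() {
  MiniMetaTensor src{{2, 3}, 1, {{0, 2}}};
  MiniMetaTensor dst;
  dst.share_meta(src);
  return dst.dims.size() == 2 ? 0 : 1;
}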
@@ -215,7 +215,7 @@ REGISTER_OPERATOR(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker,
                   ops::SoftmaxOpGradMaker<paddle::framework::OpDesc>,
                   ops::SoftmaxOpGradMaker<paddle::imperative::OpBase>,
                   ops::SoftmaxInplaceInferer, SoftmaxInferShapeFunctor);
-DECLARE_INFER_SHAPE_FUNCTOR(softmax_grad, SoftmaxGradnferShapeFunctor,
+DECLARE_INFER_SHAPE_FUNCTOR(softmax_grad, SoftmaxGradInferShapeFunctor,
                             PD_INFER_META(phi::GeneralUnaryGradInferMeta));
 REGISTER_OPERATOR(softmax_grad, ops::SoftmaxOpGrad,
-                  SoftmaxGradnferShapeFunctor);
+                  SoftmaxGradInferShapeFunctor);
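The hunk above is purely a spelling fix: the grad functor name was missing an "I" (SoftmaxGradnferShapeFunctor), and both the DECLARE_INFER_SHAPE_FUNCTOR declaration and the softmax_grad registration now use the corrected SoftmaxGradInferShapeFunctor.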
@@ -110,7 +110,7 @@ void MetaTensor::share_meta(const MetaTensor& meta_tensor) {
   }
 }

-TensorBase* MetaTensor::get_tensor() const { return tensor_; }
+TensorBase* MetaTensor::tensor() const { return tensor_; }

 void MetaTensor::share_dims(const MetaTensor& meta_tensor) {
   bool is_dense_tensor = phi::DenseTensor::classof(tensor_);
@@ -118,7 +118,7 @@ void MetaTensor::share_dims(const MetaTensor& meta_tensor) {
   if (is_dense_tensor || is_selected_rows) {
     set_dims(meta_tensor.dims());
     if (is_selected_rows) {
-      const auto in_tensor_base = meta_tensor.get_tensor();
+      const auto in_tensor_base = meta_tensor.tensor();
       PADDLE_ENFORCE_EQ(
           phi::SelectedRows::classof(in_tensor_base),
           true,
......
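Two things are visible in the hunks above: the accessor MetaTensor::get_tensor() is renamed to MetaTensor::tensor(), and the surrounding context shows the classof()-based runtime type check guarding the SelectedRows branch. A self-contained sketch of that LLVM-style classof pattern, using an illustrative Kind enum rather than phi's actual type-id machinery:

// A minimal sketch of the classof() dispatch used in share_dims;
// TensorKind and these class definitions are assumptions for
// illustration, not phi's real hierarchy.
#include <cassert>

struct TensorBase {
  enum class Kind { Dense, SelectedRows };
  explicit TensorBase(Kind k) : kind(k) {}
  virtual ~TensorBase() = default;
  Kind kind;
};

struct DenseTensor : TensorBase {
  DenseTensor() : TensorBase(Kind::Dense) {}
  static bool classof(const TensorBase* t) {
    return t && t->kind == Kind::Dense;
  }
};

struct SelectedRows : TensorBase {
  SelectedRows() : TensorBase(Kind::SelectedRows) {}
  static bool classof(const TensorBase* t) {
    return t && t->kind == Kind::SelectedRows;
  }
};

int main() {
  SelectedRows rows;
  TensorBase* base = &rows;
  // Only treat the tensor as SelectedRows when classof agrees,
  // as the PADDLE_ENFORCE_EQ check in the hunk does.
  assert(SelectedRows::classof(base));
  assert(!DenseTensor::classof(base));
  return 0;
}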
@@ -66,7 +66,7 @@ class MetaTensor {
   // Because the lod in compiletime and runtime is different,
   // so `LoD` cannot in public methods
   const LoD& lod() const;

-  TensorBase* get_tensor() const;
+  TensorBase* tensor() const;

   TensorBase* tensor_;
 };
......
@@ -39,7 +39,7 @@ void ComputeInverseEigen(const Context& dev_ctx,
   int batch_size = rank > 2 ? a.numel() / (n * n) : 1;

   const T* a_ptr = a.data<T>();
-  T* a_inv_ptr = a_inv->mutable_data<T>(dev_ctx.GetPlace());
+  T* a_inv_ptr = dev_ctx.template Alloc<T>(a_inv);

   for (int i = 0; i < batch_size; ++i) {
     ConstEigenMatrixMap mat(a_ptr + i * n * n, n, n);
......
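The only change in this last hunk replaces the legacy a_inv->mutable_data<T>(place) allocation with the phi-style dev_ctx.template Alloc<T>(a_inv). Below is a standalone sketch of the surrounding batched-inverse loop, assuming Eigen is available; the BatchedInverse function and its plain-pointer interface are illustrative, not Paddle's kernel signature, and buffer management uses std::vector instead of the device context allocator:

// A standalone sketch of the batched-inverse loop seen in
// ComputeInverseEigen, built directly on Eigen::Map views.
#include <Eigen/Dense>
#include <vector>

using ConstMap = Eigen::Map<const Eigen::MatrixXf>;
using MutableMap = Eigen::Map<Eigen::MatrixXf>;

void BatchedInverse(const float* a, float* a_inv, int batch_size, int n) {
  for (int i = 0; i < batch_size; ++i) {
    ConstMap mat(a + i * n * n, n, n);        // view of the i-th input matrix
    MutableMap inv(a_inv + i * n * n, n, n);  // view of the i-th output matrix
    inv = mat.inverse();                      // no singularity check here
  }
}

int main() {
  std::vector<float> a = {2, 0, 0, 2};  // one 2x2 matrix, batch_size = 1
  std::vector<float> a_inv(4);
  BatchedInverse(a.data(), a_inv.data(), /*batch_size=*/1, /*n=*/2);
  return a_inv[0] == 0.5f ? 0 : 1;  // inverse of diag(2,2) is diag(0.5,0.5)
}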