Unverified commit 8df91763, authored by Chen Weihang, committed by GitHub

[Phi] Move mean op kernel into phi (#40872)

* add mean phi kernel

* remove original mean kernel

* add alias name
Parent 6d3db9c7
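At a glance, the migration replaces fluid's per-device `OpKernel` class templates with phi's function-style kernels, registered under the new internal name `mean_all`. Both registration forms appear verbatim in the diff below; abbreviated here for contrast:

```cpp
// Before (fluid): class-template kernels, one registration entry per dtype.
REGISTER_OP_CPU_KERNEL(
    mean, ops::MeanKernel<paddle::platform::CPUDeviceContext, float>,
    ops::MeanKernel<paddle::platform::CPUDeviceContext, double>, ...);

// After (phi): a free function template, registered per backend and dtype.
PD_REGISTER_KERNEL(mean_all, CPU, ALL_LAYOUT, phi::MeanAllKernel,
                   float, double, ...) {}
```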
@@ -12,11 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/fluid/operators/mean_op.h"
 #include <memory>
 #include <string>
 #include <unordered_map>
+#include "paddle/fluid/framework/op_registry.h"
 namespace paddle {
 namespace operators {
@@ -94,21 +95,3 @@ REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanOpInferVarType,
                   ops::MeanGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(mean_grad, ops::MeanGradOp,
                   ops::MeanGradNoNeedBufferVarsInferer);
-REGISTER_OP_CPU_KERNEL(
-    mean, ops::MeanKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::MeanKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::MeanKernel<paddle::platform::CPUDeviceContext,
-                    paddle::platform::bfloat16>,
-    ops::MeanKernel<paddle::platform::CPUDeviceContext,
-                    paddle::platform::complex<float>>,
-    ops::MeanKernel<paddle::platform::CPUDeviceContext,
-                    paddle::platform::complex<double>>);
-REGISTER_OP_CPU_KERNEL(
-    mean_grad, ops::MeanGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::MeanGradKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::MeanGradKernel<paddle::platform::CPUDeviceContext,
-                        paddle::platform::bfloat16>,
-    ops::MeanGradKernel<paddle::platform::CPUDeviceContext,
-                        paddle::platform::complex<float>>,
-    ops::MeanGradKernel<paddle::platform::CPUDeviceContext,
-                        paddle::platform::complex<double>>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
#include "paddle/fluid/operators/mean_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
template <typename T>
__global__ void MeanRunKernel(const T* in_data, T* out_data, int N) {
using MT = typename details::MPTypeTrait<T>::Type;
int idx = blockDim.x * blockIdx.x + threadIdx.x;
auto data = static_cast<MT>(in_data[0]);
for (; idx < N; idx += blockDim.x * gridDim.x) {
out_data[idx] = static_cast<T>(data / (static_cast<MT>(N)));
}
}
template <typename DeviceContext, typename T>
class MeanCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input = context.Input<Tensor>("X");
auto* output = context.Output<Tensor>("Out");
const T* in_data = input->data<T>();
T* out_data = output->mutable_data<T>(context.GetPlace());
auto numel = input->numel();
auto rank = input->dims().size();
auto place = context.GetPlace();
auto stream = context.cuda_device_context().stream();
if (rank == 0) { // scalar
auto gpu_place = place;
memory::Copy(gpu_place, out_data, gpu_place, in_data, numel * sizeof(T),
stream);
return;
}
using Div = kernel_primitives::DivideFunctor<T, T>;
std::vector<int> reduce_dims;
reduce_dims.reserve(rank);
for (decltype(rank) i = 0; i < rank; ++i) {
reduce_dims.push_back(i);
}
TensorReduceImpl<T, T, kernel_primitives::AddFunctor,
kps::IdentityFunctor<T>>(
context.cuda_device_context(), *input, output,
kps::IdentityFunctor<T>(), reduce_dims, stream, true);
}
};
template <typename DeviceContext, typename T>
class MeanCUDAGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto OG = context.Input<Tensor>(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(OG->numel(), 1,
platform::errors::InvalidArgument(
"Mean Gradient Input Tensor len should be 1. But "
"received Out@Grad's elements num is %d.",
OG->numel()));
auto IG = context.Output<Tensor>(framework::GradVarName("X"));
IG->mutable_data<T>(context.GetPlace());
auto in_data = OG->data<T>();
auto size_prob = IG->numel();
auto out_data = IG->data<T>();
int threads = 512;
int grid = (size_prob + threads - 1) / threads;
auto stream = context.cuda_device_context().stream();
MeanRunKernel<T><<<grid, threads, 0, stream>>>(in_data, out_data,
size_prob);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
mean, ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
mean_grad,
ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename DeviceContext, typename T>
class MeanKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input = context.Input<Tensor>("X");
auto* output = context.Output<Tensor>("Out");
output->mutable_data<T>(context.GetPlace());
auto X = EigenVector<T>::Flatten(*input);
auto y = EigenScalar<T>::From(*output);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
y.device(place) = X.mean();
}
};
template <typename DeviceContext, typename T>
class MeanGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto OG = context.Input<Tensor>(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(OG->numel(), 1UL,
platform::errors::InvalidArgument(
"Mean Gradient should be scalar. But received "
"Out@Grad's elements num is %d.",
OG->numel()));
auto IG = context.Output<Tensor>(framework::GradVarName("X"));
IG->mutable_data<T>(context.GetPlace());
T ig_size = static_cast<T>(IG->numel());
Eigen::DSizes<int, 1> bcast(static_cast<int>(ig_size));
EigenVector<T>::Flatten(*IG).device(
*context.template device_context<DeviceContext>().eigen_device()) =
(EigenVector<T>::From(*OG) / ig_size).broadcast(bcast);
}
};
} // namespace operators
} // namespace paddle
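For reference, the deleted Eigen kernels above and the new phi kernels later in this diff compute the same quantities: the forward pass reduces the whole input to a scalar mean, and the backward pass spreads the upstream scalar gradient evenly over every input element:

```latex
y = \frac{1}{N}\sum_{i=1}^{N} x_i,
\qquad
\frac{\partial L}{\partial x_i} = \frac{1}{N}\,\frac{\partial L}{\partial y},
\qquad N = \operatorname{numel}(x)
```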
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "paddle/fluid/operators/mean_op.h"
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/mlu/mlu_baseop.h"
 #include "paddle/fluid/platform/device/mlu/device_context.h"
 #include "paddle/fluid/platform/float16.h"
@@ -20,6 +20,8 @@
 namespace paddle {
 namespace operators {
+using Tensor = framework::Tensor;
 template <typename T>
 class MeanMLUKernel : public framework::OpKernel<T> {
  public:
...
@@ -9,13 +9,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/fluid/operators/mean_op.h"
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
 #include "paddle/fluid/platform/float16.h"
 namespace paddle {
 namespace operators {
+using Tensor = framework::Tensor;
 template <typename DeviceContext, typename T>
 class MeanNPUKernel : public framework::OpKernel<T> {
  public:
...
@@ -12,15 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/fluid/operators/mean_op.h"
 #ifdef PADDLE_WITH_XPU
 #include <memory>
 #include <string>
 #include <unordered_map>
+#include "paddle/fluid/framework/op_registry.h"
 namespace paddle {
 namespace operators {
+using Tensor = framework::Tensor;
 template <typename DeviceContext, typename T>
 class MeanXPUKernel : public framework::OpKernel<T> {
   using XPUType = typename XPUTypeTrait<T>::Type;
...
@@ -16,7 +16,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/generator.h"
 #include "paddle/fluid/operators/fill_constant_op.h"
-#include "paddle/fluid/operators/mean_op.h"
 namespace paddle {
 namespace operators {
...
@@ -50,8 +50,6 @@ const std::unordered_set<std::string> deprecated_op_names({"diag",
                                                            "matmul",
                                                            "matmul_grad",
                                                            "matmul_grad_grad",
-                                                           "mean",
-                                                           "mean_grad",
                                                            "max",
                                                            "max_grad",
                                                            "min",
...
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/mean_all_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace phi {
template <typename T, typename Context>
void MeanAllGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
PADDLE_ENFORCE_EQ(out_grad.numel(),
1UL,
phi::errors::InvalidArgument(
"Mean Gradient should be scalar. But received "
"Out@Grad's elements num is %d.",
out_grad.numel()));
dev_ctx.template Alloc<T>(x_grad);
T ig_size = static_cast<T>(x_grad->numel());
Eigen::DSizes<int, 1> bcast(static_cast<int>(ig_size));
EigenVector<T>::Flatten(*x_grad).device(*dev_ctx.eigen_device()) =
(EigenVector<T>::From(out_grad) / ig_size).broadcast(bcast);
}
} // namespace phi
PD_REGISTER_KERNEL(mean_all_grad,
CPU,
ALL_LAYOUT,
phi::MeanAllGradKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/mean_all_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace phi {
template <typename T, typename Context>
void MeanAllKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
auto X = EigenVector<T>::Flatten(x);
auto y = EigenScalar<T>::From(*out);
auto& place = *dev_ctx.eigen_device();
y.device(place) = X.mean();
}
} // namespace phi
PD_REGISTER_KERNEL(mean_all,
CPU,
ALL_LAYOUT,
phi::MeanAllKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
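For orientation, a minimal usage sketch of the CPU kernel registered above. `MakeCpuTensor` is a hypothetical helper (building a `phi::DenseTensor` requires an allocator and tensor meta, elided here), so treat this as an illustration of the call signature rather than a drop-in program:

```cpp
#include <vector>
#include "paddle/phi/kernels/mean_all_kernel.h"

// Hypothetical helper (NOT part of phi): builds a 1-D float CPU tensor.
phi::DenseTensor MakeCpuTensor(const std::vector<float>& vals);

void MeanAllExample(const phi::CPUContext& dev_ctx) {
  phi::DenseTensor x = MakeCpuTensor({1.f, 2.f, 3.f, 4.f});
  phi::DenseTensor out;
  // Matches the signature declared in mean_all_kernel.h below.
  phi::MeanAllKernel<float, phi::CPUContext>(dev_ctx, x, &out);
  // out now holds one element: (1 + 2 + 3 + 4) / 4 = 2.5f
}
```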
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/mean_all_kernel.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T>
__global__ void MeanRunKernel(const T* in_data, T* out_data, int N) {
using MT = typename dtype::MPTypeTrait<T>::Type;
int idx = blockDim.x * blockIdx.x + threadIdx.x;
auto data = static_cast<MT>(in_data[0]);
for (; idx < N; idx += blockDim.x * gridDim.x) {
out_data[idx] = static_cast<T>(data / (static_cast<MT>(N)));
}
}
template <typename T, typename Context>
void MeanAllGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
PADDLE_ENFORCE_EQ(out_grad.numel(),
1,
phi::errors::InvalidArgument(
"Mean Gradient Input Tensor len should be 1. But "
"received Out@Grad's elements num is %d.",
out_grad.numel()));
dev_ctx.template Alloc<T>(x_grad);
auto in_data = out_grad.data<T>();
auto size_prob = x_grad->numel();
auto out_data = x_grad->data<T>();
int threads = 512;
int grid = (size_prob + threads - 1) / threads;
auto stream = dev_ctx.stream();
MeanRunKernel<T><<<grid, threads, 0, stream>>>(in_data, out_data, size_prob);
}
} // namespace phi
PD_REGISTER_KERNEL(mean_all_grad,
GPU,
ALL_LAYOUT,
phi::MeanAllGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
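A quick sanity check on the launch configuration above: `grid = (size_prob + threads - 1) / threads` is a ceiling division, so with `size_prob = 1000` and `threads = 512` it launches 2 blocks (1024 threads) covering all 1000 elements; the grid-stride loop in `MeanRunKernel` (`idx += blockDim.x * gridDim.x`) keeps the kernel correct even when the launched grid has fewer threads than elements.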
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/mean_all_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/primitive/functor_primitives.h"
#include "paddle/fluid/memory/memcpy.h"
namespace phi {
template <typename T, typename Context>
void MeanAllKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
const T* in_data = x.data<T>();
T* out_data = dev_ctx.template Alloc<T>(out);
auto numel = x.numel();
auto rank = x.dims().size();
auto place = dev_ctx.GetPlace();
auto stream = dev_ctx.stream();
if (rank == 0) { // scalar
paddle::memory::Copy(
place, out_data, place, in_data, numel * sizeof(T), stream);
return;
}
std::vector<int> reduce_dims;
reduce_dims.reserve(rank);
for (decltype(rank) i = 0; i < rank; ++i) {
reduce_dims.push_back(i);
}
funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dev_ctx,
x,
out,
kps::IdentityFunctor<T>(),
reduce_dims,
/*is_mean=*/true);
}
} // namespace phi
PD_REGISTER_KERNEL(mean_all,
GPU,
ALL_LAYOUT,
phi::MeanAllKernel,
float,
double,
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
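Two paths in the forward GPU kernel above: a rank-0 (scalar) input is handled with a plain device-to-device copy, since the mean of a scalar is itself; everything else is reduced over all dimensions in one `funcs::ReduceKernel` call, where `/*is_mean=*/true` folds the division by the element count into the reduction so no separate scaling kernel is needed.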
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void MeanAllGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
// This kernel exists for compatibility with the `mean` op in fluid and is
// no longer used by the 2.x API. It cannot be implemented by calling
// ReduceMeanKernel, because ReduceMeanKernel does not support bfloat16 yet;
// the two may be unified into the ReduceMean kernel series in the future.
template <typename T, typename Context>
void MeanAllKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature MeanOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("mean_all", {"X"}, {}, {"Out"});
}
KernelSignature MeanGradOpGradArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"mean_all_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")});
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(mean, mean_all);
PD_REGISTER_BASE_KERNEL_NAME(mean_grad, mean_all_grad);
PD_REGISTER_ARG_MAPPING_FN(mean, phi::MeanOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(mean_grad, phi::MeanGradOpGradArgumentMapping);
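The `PD_REGISTER_BASE_KERNEL_NAME` lines implement the "add alias name" item from the commit message: they record that the legacy op names `mean` and `mean_grad` are now served by the phi kernels `mean_all` and `mean_all_grad`, while the argument-mapping functions translate the fluid variable names (`X`, `Out`, and their gradients) into the phi kernel signatures.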