未验证 提交 48b4366c 编写于 作者: Y Yang 提交者: GitHub

[Phi] move ops: maxout/take_along_axis/put_along_axis (#39959)

* [Phi] move put_along_axis/take_along_axis/maxout

* use phi::Copy
上级 00566ead
......@@ -14,106 +14,107 @@ limitations under the License. */
#include "paddle/fluid/operators/math/maxouting.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
namespace paddle {
namespace operators {
namespace math {
// All tensors are in NCHW or NHWC format, and the groups must be greater than 1
template <typename T>
class MaxOutFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input, framework::Tensor* output,
const int groups, const int axis) {
const int batch_size = input.dims()[0];
const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
const int output_channels = output->dims()[axis];
int fea_size = input_height * input_width;
// c_size means the output size of each sample
int c_size = fea_size * output_channels;
const T* input_data = input.data<T>();
T* output_data = output->mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; ++i) {
int new_bindex = c_size * i;
for (int c = 0; c < output_channels; ++c) {
int new_cindex = fea_size * c;
for (int f = 0; f < fea_size; ++f) {
T ele = static_cast<T>(-FLT_MAX);
int input_idx, output_idx;
for (int ph = 0; ph < groups; ++ph) {
if (axis == 1) {
input_idx =
(new_bindex + new_cindex) * groups + ph * fea_size + f;
} else {
input_idx = (new_bindex + f * output_channels + c) * groups + ph;
}
T x = input_data[input_idx];
ele = ele > x ? ele : x;
}
template <typename DeviceContext, typename T>
void MaxOutFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
const framework::Tensor& input,
framework::Tensor* output,
const int groups,
const int axis) {
const int batch_size = input.dims()[0];
const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
const int output_channels = output->dims()[axis];
int fea_size = input_height * input_width;
// c_size means the output size of each sample
int c_size = fea_size * output_channels;
const T* input_data = input.data<T>();
T* output_data = output->mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; ++i) {
int new_bindex = c_size * i;
for (int c = 0; c < output_channels; ++c) {
int new_cindex = fea_size * c;
for (int f = 0; f < fea_size; ++f) {
T ele = static_cast<T>(-FLT_MAX);
int input_idx, output_idx;
for (int ph = 0; ph < groups; ++ph) {
if (axis == 1) {
output_idx = new_bindex + new_cindex + f;
input_idx = (new_bindex + new_cindex) * groups + ph * fea_size + f;
} else {
output_idx = new_bindex + f * output_channels + c;
input_idx = (new_bindex + f * output_channels + c) * groups + ph;
}
output_data[output_idx] = ele;
T x = input_data[input_idx];
ele = ele > x ? ele : x;
}
if (axis == 1) {
output_idx = new_bindex + new_cindex + f;
} else {
output_idx = new_bindex + f * output_channels + c;
}
output_data[output_idx] = ele;
}
}
}
};
}
template <class T>
class MaxOutGradFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input, framework::Tensor* input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, const int groups,
const int axis) {
const int batch_size = input.dims()[0];
const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
const int output_channels = output.dims()[axis];
int fea_size = input_height * input_width;
const T* input_data = input.data<T>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
template <typename DeviceContext, typename T>
void MaxOutGradFunctor<DeviceContext, T>::operator()(
const DeviceContext& context, const framework::Tensor& input,
framework::Tensor* input_grad, const framework::Tensor& output,
const framework::Tensor& output_grad, const int groups, const int axis) {
const int batch_size = input.dims()[0];
const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
const int output_channels = output.dims()[axis];
int fea_size = input_height * input_width;
const T* input_data = input.data<T>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; ++i) {
int blen = fea_size * output_channels * i;
for (int c = 0; c < output_channels; ++c) {
int clen = fea_size * c;
for (int f = 0; f < fea_size; ++f) {
int input_idx0, output_idx;
bool continue_match = true;
if (axis == 1) {
input_idx0 = (blen + clen) * groups + f;
output_idx = blen + clen + f;
} else {
input_idx0 = (blen + f * output_channels + c) * groups;
output_idx = blen + f * output_channels + c;
}
for (int g = 0; g < groups && continue_match; ++g) {
int idx_offset = (axis == 1 ? fea_size * g : g);
int input_idx = input_idx0 + idx_offset;
if (input_data[input_idx] == output_data[output_idx]) {
input_grad_data[input_idx] += output_grad_data[output_idx];
continue_match = false;
}
for (int i = 0; i < batch_size; ++i) {
int blen = fea_size * output_channels * i;
for (int c = 0; c < output_channels; ++c) {
int clen = fea_size * c;
for (int f = 0; f < fea_size; ++f) {
int input_idx0, output_idx;
bool continue_match = true;
if (axis == 1) {
input_idx0 = (blen + clen) * groups + f;
output_idx = blen + clen + f;
} else {
input_idx0 = (blen + f * output_channels + c) * groups;
output_idx = blen + f * output_channels + c;
}
for (int g = 0; g < groups && continue_match; ++g) {
int idx_offset = (axis == 1 ? fea_size * g : g);
int input_idx = input_idx0 + idx_offset;
if (input_data[input_idx] == output_data[output_idx]) {
input_grad_data[input_idx] += output_grad_data[output_idx];
continue_match = false;
}
}
}
}
}
};
}
template class MaxOutGradFunctor<platform::CPUDeviceContext, float>;
template class MaxOutGradFunctor<platform::CPUDeviceContext, double>;
template class MaxOutFunctor<platform::CPUDeviceContext, float>;
template class MaxOutFunctor<platform::CPUDeviceContext, double>;
template class MaxOutGradFunctor<phi::CPUContext, float>;
template class MaxOutGradFunctor<phi::CPUContext, double>;
template class MaxOutFunctor<phi::CPUContext, float>;
template class MaxOutFunctor<phi::CPUContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/maxouting.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
namespace paddle {
namespace operators {
......@@ -95,61 +96,57 @@ __global__ void KernelMaxoutGrad(const int nthreads, const T* input_data,
/*
* All tensors are in NCHW or NHWC format.
*/
template <typename T>
class MaxOutFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input, framework::Tensor* output,
const int groups, const int axis) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[axis];
const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
const int output_channels = output->dims()[axis];
const T* input_data = input.data<T>();
T* output_data = output->mutable_data<T>(context.GetPlace());
int nthreads = output->numel();
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelMaxOut<T><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, input_channels, input_height, input_width, groups,
axis, output_data);
}
};
template <typename DeviceContext, typename T>
void MaxOutFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
const framework::Tensor& input,
framework::Tensor* output,
const int groups,
const int axis) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[axis];
const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
const int output_channels = output->dims()[axis];
const T* input_data = input.data<T>();
T* output_data = output->mutable_data<T>(context.GetPlace());
int nthreads = output->numel();
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelMaxOut<T><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, input_channels, input_height, input_width, groups,
axis, output_data);
}
/*
* All tensors are in NCHW or NHWC format.
*/
template <typename T>
class MaxOutGradFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input, framework::Tensor* input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, const int groups,
const int axis) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[axis];
const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
const int output_channels = output.dims()[axis];
const T* input_data = input.data<T>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
int nthreads = output.numel();
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelMaxoutGrad<T><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, output_data, output_grad_data, input_grad_data,
input_channels, input_height, input_width, groups, axis);
}
};
template <typename DeviceContext, typename T>
void MaxOutGradFunctor<DeviceContext, T>::operator()(
const DeviceContext& context, const framework::Tensor& input,
framework::Tensor* input_grad, const framework::Tensor& output,
const framework::Tensor& output_grad, const int groups, const int axis) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[axis];
const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
const int output_channels = output.dims()[axis];
const T* input_data = input.data<T>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
int nthreads = output.numel();
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelMaxoutGrad<T><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, output_data, output_grad_data, input_grad_data,
input_channels, input_height, input_width, groups, axis);
}
template class MaxOutGradFunctor<platform::CUDADeviceContext, float>;
template class MaxOutGradFunctor<platform::CUDADeviceContext, double>;
......@@ -157,6 +154,12 @@ template class MaxOutGradFunctor<platform::CUDADeviceContext, double>;
template class MaxOutFunctor<platform::CUDADeviceContext, float>;
template class MaxOutFunctor<platform::CUDADeviceContext, double>;
template class MaxOutGradFunctor<phi::GPUContext, float>;
template class MaxOutGradFunctor<phi::GPUContext, double>;
template class MaxOutFunctor<phi::GPUContext, float>;
template class MaxOutFunctor<phi::GPUContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -30,7 +30,7 @@ class MaxOutFunctor {
const int axis = 1);
};
template <typename DeviceContext, class T>
template <typename DeviceContext, typename T>
class MaxOutGradFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor& input,
......
......@@ -12,14 +12,14 @@
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/maxout_op.h"
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
......@@ -130,10 +130,3 @@ REGISTER_OPERATOR(
paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>);
REGISTER_OPERATOR(maxout_grad, ops::MaxOutOpGrad);
REGISTER_OP_CPU_KERNEL(
maxout, ops::MaxOutKernel<paddle::platform::CPUDeviceContext, float>,
ops::MaxOutKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
maxout_grad,
ops::MaxOutGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::MaxOutGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/maxout_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
maxout, ops::MaxOutKernel<paddle::platform::CUDADeviceContext, float>,
ops::MaxOutKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
maxout_grad,
ops::MaxOutGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::MaxOutGradKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/maxouting.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class MaxOutKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* in_x = context.Input<Tensor>("X");
Tensor* out = context.Output<Tensor>("Out");
int groups = context.template Attr<int>("groups");
int axis = context.template Attr<int>("axis");
if (axis < 0) {
axis += in_x->dims().size();
}
math::MaxOutFunctor<DeviceContext, T> maxout_forward;
maxout_forward(context.template device_context<DeviceContext>(), *in_x, out,
groups, axis);
}
};
template <typename DeviceContext, typename T>
class MaxOutGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* in_x = context.Input<Tensor>("X");
const Tensor* out = context.Input<Tensor>("Out");
const Tensor* out_grad =
context.Input<Tensor>(framework::GradVarName("Out"));
Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
int groups = context.template Attr<int>("groups");
int axis = context.template Attr<int>("axis");
if (axis < 0) {
axis += in_x->dims().size();
}
auto& device_ctx = context.template device_context<DeviceContext>();
phi::funcs::SetConstant<DeviceContext, T> zero;
if (in_x_grad) {
in_x_grad->mutable_data<T>(context.GetPlace());
zero(device_ctx, in_x_grad, static_cast<T>(0.0));
math::MaxOutGradFunctor<DeviceContext, T> maxout_backward;
maxout_backward(device_ctx, *in_x, in_x_grad, *out, *out_grad, groups,
axis);
}
}
};
} // namespace operators
} // namespace paddle
......@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/put_along_axis_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/ddim.h"
......@@ -123,16 +124,3 @@ REGISTER_OPERATOR(put_along_axis, ops::PutAlongAxisOp, ops::PutAlongAxisOpMaker,
paddle::operators::PutAlongAxisInplaceInferer);
REGISTER_OPERATOR(put_along_axis_grad, ops::PutAlongAxisGradOp);
REGISTER_OP_CPU_KERNEL(put_along_axis, ops::PutAlongAxisOpKernel<float>,
ops::PutAlongAxisOpKernel<double>,
ops::PutAlongAxisOpKernel<int>,
ops::PutAlongAxisOpKernel<uint8_t>,
ops::PutAlongAxisOpKernel<int64_t>);
REGISTER_OP_CPU_KERNEL(put_along_axis_grad,
ops::PutAlongAxisGradOpKernel<float>,
ops::PutAlongAxisGradOpKernel<double>,
ops::PutAlongAxisGradOpKernel<int>,
ops::PutAlongAxisGradOpKernel<uint8_t>,
ops::PutAlongAxisGradOpKernel<int64_t>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/put_along_axis_op.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
template <typename T>
class PutAlongAxisCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet(
"PutAlongAxisCUDAKernel only runs on GPU device."));
auto input = ctx.Input<Tensor>("Input");
auto axis = ctx.Attr<int>("Axis");
auto value = ctx.Input<Tensor>("Value");
auto index = ctx.Input<Tensor>("Index");
auto reduce_op = ctx.Attr<std::string>("Reduce");
auto result = ctx.Output<Tensor>("Result");
const platform::DeviceContext &device_ctx = ctx.device_context();
const auto &index_type = framework::TransToProtoVarType(index->dtype());
framework::TensorCopy(*input, ctx.GetPlace(), result);
if (reduce_op == "add") {
if (index_type == framework::proto::VarType::INT32) {
gpu_scatter_add_kernel<T, int32_t>(*result, axis, *index, *value,
device_ctx);
} else if (index_type == framework::proto::VarType::INT64) {
gpu_scatter_add_kernel<T, int64_t>(*result, axis, *index, *value,
device_ctx);
}
} else if (reduce_op == "multiply" || reduce_op == "mul") {
if (index_type == framework::proto::VarType::INT32) {
gpu_scatter_mul_kernel<T, int32_t>(*result, axis, *index, *value,
device_ctx);
} else if (index_type == framework::proto::VarType::INT64) {
gpu_scatter_mul_kernel<T, int64_t>(*result, axis, *index, *value,
device_ctx);
}
} else if (reduce_op == "assign") {
if (index_type == framework::proto::VarType::INT32) {
gpu_scatter_assign_kernel<T, int32_t>(*result, axis, *index, *value,
device_ctx);
} else if (index_type == framework::proto::VarType::INT64) {
gpu_scatter_assign_kernel<T, int64_t>(*result, axis, *index, *value,
device_ctx);
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"can not support reduce_op: '%s' for scatter kernel, only "
"support reduce op: 'add‘, 'assign', 'mul' and 'multiply', the "
"defalut reduce op is 'assign' ",
reduce_op));
return;
}
}
};
template <typename T>
class PutAlongAxisGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet(
"PutAlongAxisGradOpCUDAKernel only runs on GPU."));
auto input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
auto value_grad = ctx.Output<Tensor>(framework::GradVarName("Value"));
auto index = ctx.Input<Tensor>("Index");
auto result_grad = ctx.Input<Tensor>(framework::GradVarName("Result"));
auto axis = ctx.Attr<int>("Axis");
const auto &index_type = framework::TransToProtoVarType(index->dtype());
if (input_grad) {
framework::TensorCopy(*result_grad, ctx.GetPlace(), input_grad);
if (index_type == framework::proto::VarType::INT32) {
gpu_scatter_input_grad_kernel<T, int32_t>(
*result_grad, axis, *index, *input_grad, ctx.device_context());
} else {
gpu_scatter_input_grad_kernel<T, int64_t>(
*result_grad, axis, *index, *input_grad, ctx.device_context());
}
}
if (value_grad) {
value_grad->Resize(index->dims());
value_grad->mutable_data<T>(ctx.GetPlace());
if (index_type == framework::proto::VarType::INT32) {
gpu_gather_kernel<T, int32_t>(
*result_grad, axis, *index, *value_grad,
ctx.device_context()); // the gradient of scatter is gather
} else if (index_type == framework::proto::VarType::INT64) {
gpu_gather_kernel<T, int64_t>(*result_grad, axis, *index, *value_grad,
ctx.device_context());
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(put_along_axis, ops::PutAlongAxisCUDAKernel<float>,
ops::PutAlongAxisCUDAKernel<double>,
ops::PutAlongAxisCUDAKernel<int64_t>,
ops::PutAlongAxisCUDAKernel<int>,
ops::PutAlongAxisCUDAKernel<plat::float16>);
REGISTER_OP_CUDA_KERNEL(put_along_axis_grad,
ops::PutAlongAxisGradOpCUDAKernel<float>,
ops::PutAlongAxisGradOpCUDAKernel<double>,
ops::PutAlongAxisGradOpCUDAKernel<int64_t>,
ops::PutAlongAxisGradOpCUDAKernel<int>,
ops::PutAlongAxisGradOpCUDAKernel<plat::float16>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class PutAlongAxisOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet(
"PutAlongAxisOpKernel only runs on CPU."));
auto input = ctx.Input<Tensor>("Input");
auto axis = ctx.Attr<int>("Axis");
auto value = ctx.Input<Tensor>("Value");
auto index = ctx.Input<Tensor>("Index");
auto reduce_op = ctx.Attr<std::string>("Reduce");
auto result = ctx.Output<Tensor>("Result");
framework::TensorCopy(*input, ctx.GetPlace(), result);
const platform::DeviceContext &device_ctx = ctx.device_context();
const auto &index_type = framework::TransToProtoVarType(index->dtype());
if (reduce_op == "add") {
if (index_type == framework::proto::VarType::INT32) {
cpu_scatter_add_kernel<T, int32_t>(*result, axis, *index, *value,
device_ctx);
} else if (index_type == framework::proto::VarType::INT64) {
cpu_scatter_add_kernel<T, int64_t>(*result, axis, *index, *value,
device_ctx);
}
} else if (reduce_op == "multiply" || reduce_op == "mul") {
if (index_type == framework::proto::VarType::INT32) {
cpu_scatter_mul_kernel<T, int32_t>(*result, axis, *index, *value,
device_ctx);
} else if (index_type == framework::proto::VarType::INT64) {
cpu_scatter_mul_kernel<T, int64_t>(*result, axis, *index, *value,
device_ctx);
}
} else if (reduce_op == "assign") {
if (index_type == framework::proto::VarType::INT32) {
cpu_scatter_assign_kernel<T, int32_t>(*result, axis, *index, *value,
device_ctx);
} else if (index_type == framework::proto::VarType::INT64) {
cpu_scatter_assign_kernel<T, int64_t>(*result, axis, *index, *value,
device_ctx);
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"can not support reduce_op: '%s' for scatter kernel, only "
"support reduce op: 'add‘, 'assign', 'mul' and 'multiply', the "
"defalut reduce "
"op is 'assign' ",
reduce_op));
return;
}
}
};
template <typename T>
class PutAlongAxisGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet(
"PutAlongAxisGradOpKernel only runs on CPU."));
auto input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
auto value_grad = ctx.Output<Tensor>(framework::GradVarName("Value"));
auto index = ctx.Input<Tensor>("Index");
auto result_grad = ctx.Input<Tensor>(framework::GradVarName("Result"));
auto axis = ctx.Attr<int>("Axis");
const auto &index_type = framework::TransToProtoVarType(index->dtype());
if (input_grad) {
framework::TensorCopy(*result_grad, ctx.GetPlace(), input_grad);
if (index_type == framework::proto::VarType::INT32) {
cpu_scatter_input_grad_kernel<T, int32_t>(
// Here passing an unused argument *result_grad, because it's
// convenient to instantiate a bunch of template function with the
// same arguments list.
*result_grad, axis, *index, *input_grad, ctx.device_context());
} else {
cpu_scatter_input_grad_kernel<T, int64_t>(
*result_grad, axis, *index, *input_grad, ctx.device_context());
}
}
if (value_grad) {
value_grad->Resize(index->dims());
value_grad->mutable_data<T>(ctx.GetPlace());
if (index_type == framework::proto::VarType::INT32) {
cpu_gather_kernel<T, int32_t>(*result_grad, axis, *index, *value_grad,
ctx.device_context());
} else if (index_type == framework::proto::VarType::INT64) {
cpu_gather_kernel<T, int64_t>(*result_grad, axis, *index, *value_grad,
ctx.device_context());
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/take_along_axis_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/ddim.h"
......@@ -139,16 +140,3 @@ REGISTER_OPERATOR(take_along_axis, ops::TakeAlongAxisOp,
ops::TakeAlongAxisGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(take_along_axis_grad, ops::TakeAlongAxisGradOp);
REGISTER_OP_CPU_KERNEL(take_along_axis, ops::TakeAlongAxisOpKernel<float>,
ops::TakeAlongAxisOpKernel<double>,
ops::TakeAlongAxisOpKernel<int>,
ops::TakeAlongAxisOpKernel<uint8_t>,
ops::TakeAlongAxisOpKernel<int64_t>);
REGISTER_OP_CPU_KERNEL(take_along_axis_grad,
ops::TakeAlongAxisGradOpKernel<float>,
ops::TakeAlongAxisGradOpKernel<double>,
ops::TakeAlongAxisGradOpKernel<int>,
ops::TakeAlongAxisGradOpKernel<uint8_t>,
ops::TakeAlongAxisGradOpKernel<int64_t>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/take_along_axis_op.h"
#include "paddle/phi/core/ddim.h"
namespace paddle {
namespace operators {
template <typename T>
class TakeAlongAxisCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet(
"This kernel only runs on GPU device."));
auto input = ctx.Input<Tensor>("Input");
auto axis = ctx.Attr<int>("Axis");
auto index = ctx.Input<Tensor>("Index");
auto result = ctx.Output<Tensor>("Result");
result->Resize(index->dims());
result->mutable_data<T>(ctx.GetPlace());
const auto &index_type = framework::TransToProtoVarType(index->dtype());
if (index_type == framework::proto::VarType::INT32) {
gpu_gather_kernel<T, int32_t>(*input, axis, *index, *result,
ctx.device_context());
} else if (index_type == framework::proto::VarType::INT64) {
gpu_gather_kernel<T, int64_t>(*input, axis, *index, *result,
ctx.device_context());
}
}
};
template <typename T>
class TakeAlongAxisGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet("This kernel only runs on GPU."));
auto input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
auto index = ctx.Input<Tensor>("Index");
auto result_grad = ctx.Input<Tensor>(framework::GradVarName("Result"));
auto axis = ctx.Attr<int>("Axis");
// We need to know the shape of input matrix to determine the shape of grad
// matrix of input.
auto input = ctx.Input<Tensor>("Input");
input_grad->Resize(input->dims());
input_grad->mutable_data<T>(ctx.GetPlace());
// Set to zero tensor.
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
phi::funcs::SetConstant<platform::CUDADeviceContext, T> functor;
functor(reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx),
input_grad, static_cast<T>(0));
const auto &index_type = framework::TransToProtoVarType(index->dtype());
if (index_type == framework::proto::VarType::INT32) {
gpu_scatter_add_kernel<T, int32_t>(
*input_grad, axis, *index, *result_grad,
ctx.device_context()); // the gradient of gather is scatter
} else if (index_type == framework::proto::VarType::INT64) {
gpu_scatter_add_kernel<T, int64_t>(*input_grad, axis, *index,
*result_grad, ctx.device_context());
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(take_along_axis, ops::TakeAlongAxisCUDAKernel<float>,
ops::TakeAlongAxisCUDAKernel<double>,
ops::TakeAlongAxisCUDAKernel<int64_t>,
ops::TakeAlongAxisCUDAKernel<int>,
ops::TakeAlongAxisCUDAKernel<plat::float16>);
REGISTER_OP_CUDA_KERNEL(take_along_axis_grad,
ops::TakeAlongAxisGradOpCUDAKernel<float>,
ops::TakeAlongAxisGradOpCUDAKernel<double>,
ops::TakeAlongAxisGradOpCUDAKernel<int64_t>,
ops::TakeAlongAxisGradOpCUDAKernel<int>,
ops::TakeAlongAxisGradOpCUDAKernel<plat::float16>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class TakeAlongAxisOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
auto input = ctx.Input<Tensor>("Input");
auto axis = ctx.Attr<int>("Axis");
auto index = ctx.Input<Tensor>("Index");
auto result = ctx.Output<Tensor>("Result");
result->Resize(index->dims());
result->mutable_data<T>(ctx.GetPlace());
const auto &index_type = framework::TransToProtoVarType(index->dtype());
if (index_type == framework::proto::VarType::INT32) {
cpu_gather_kernel<T, int32_t>(*input, axis, *index, *result,
ctx.device_context());
} else if (index_type == framework::proto::VarType::INT64) {
cpu_gather_kernel<T, int64_t>(*input, axis, *index, *result,
ctx.device_context());
}
}
};
template <typename T>
class TakeAlongAxisGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
auto input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
auto index = ctx.Input<Tensor>("Index");
auto result_grad = ctx.Input<Tensor>(framework::GradVarName("Result"));
auto axis = ctx.Attr<int>("Axis");
// We need to know the shape of input matrix to determine the shape of grad
// matrix of input.
auto input = ctx.Input<Tensor>("Input");
input_grad->Resize(input->dims());
input_grad->mutable_data<T>(ctx.GetPlace());
// Set to zero tensor.
auto &dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
phi::funcs::SetConstant<platform::CPUDeviceContext, T> functor;
functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx),
input_grad, static_cast<T>(0));
const auto &index_type = framework::TransToProtoVarType(index->dtype());
if (index_type == framework::proto::VarType::INT32) {
cpu_scatter_add_kernel<T, int32_t>(
*input_grad, axis, *index, *result_grad,
ctx.device_context()); // the gradient of gather is scatter
} else if (index_type == framework::proto::VarType::INT64) {
cpu_scatter_add_kernel<T, int64_t>(*input_grad, axis, *index,
*result_grad, ctx.device_context());
}
}
};
} // namespace operators
} // namespace paddle
......@@ -27,11 +27,17 @@ kernel_library(full_kernel DEPS ${COMMON_KERNEL_DEPS} empty_kernel)
# Some kernels depend on some targets that are not commonly used.
# These targets are not suitable for common dependencies.
# In this case, you need to manually generate them here.
set(MANUAL_BUILD_KERNELS math_kernel softmax_kernel softmax_grad_kernel triangular_solve_grad_kernel)
set(MANUAL_BUILD_KERNELS math_kernel softmax_kernel softmax_grad_kernel triangular_solve_grad_kernel maxout_kernel maxout_grad_kernel put_along_axis_kernel put_along_axis_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel)
kernel_library(math_kernel DEPS ${COMMON_KERNEL_DEPS} cast_kernel copy_kernel)
kernel_library(softmax_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
kernel_library(softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
kernel_library(triangular_solve_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_reduce)
kernel_library(maxout_kernel DEPS ${COMMON_KERNEL_DEPS} maxouting)
kernel_library(maxout_grad_kernel DEPS ${COMMON_KERNEL_DEPS} maxouting)
kernel_library(put_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
kernel_library(put_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
kernel_library(take_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
kernel_library(take_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
# 4. auto parse and build kernel targets by cmake
register_kernels(EXCLUDES ${COMMON_BAISC_KERNELS} ${MANUAL_BUILD_KERNELS} DEPS ${COMMON_KERNEL_DEPS} ${COMMON_BAISC_KERNELS} )
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/impl/maxout_grad_kernel_impl.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(
maxout_grad, CPU, ALL_LAYOUT, phi::MaxOutGradKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/impl/maxout_kernel_impl.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(maxout, CPU, ALL_LAYOUT, phi::MaxOutKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/put_along_axis_grad_kernel.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
namespace phi {
template <typename T, typename Context>
void PutAlongAxisGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& index,
const DenseTensor& out_grad,
int axis,
const std::string& reduce,
DenseTensor* x_grad,
DenseTensor* value_grad) {
PADDLE_ENFORCE_EQ(
paddle::platform::is_cpu_place(dev_ctx.GetPlace()),
true,
errors::PreconditionNotMet("PutAlongAxisGradOpKernel only runs on CPU."));
const auto& index_type =
paddle::framework::TransToProtoVarType(index.dtype());
if (x_grad) {
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
if (index_type == paddle::framework::proto::VarType::INT32) {
paddle::operators::cpu_scatter_input_grad_kernel<T, int32_t>(
// Here passing an unused argument out_grad, because it's
// convenient to instantiate a bunch of template function with the
// same arguments list.
out_grad,
axis,
index,
*x_grad,
dev_ctx);
} else {
paddle::operators::cpu_scatter_input_grad_kernel<T, int64_t>(
out_grad, axis, index, *x_grad, dev_ctx);
}
}
if (value_grad) {
value_grad->Resize(index.dims());
value_grad->mutable_data<T>(dev_ctx.GetPlace());
if (index_type == paddle::framework::proto::VarType::INT32) {
paddle::operators::cpu_gather_kernel<T, int32_t>(
out_grad, axis, index, *value_grad, dev_ctx);
} else if (index_type == paddle::framework::proto::VarType::INT64) {
paddle::operators::cpu_gather_kernel<T, int64_t>(
out_grad, axis, index, *value_grad, dev_ctx);
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(put_along_axis_grad,
CPU,
ALL_LAYOUT,
phi::PutAlongAxisGradKernel,
float,
double,
int,
uint8_t,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/put_along_axis_kernel.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
namespace phi {
template <typename T, typename Context>
void PutAlongAxisKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& index,
const DenseTensor& value,
int axis,
const std::string& reduce,
DenseTensor* out) {
PADDLE_ENFORCE_EQ(
paddle::platform::is_cpu_place(dev_ctx.GetPlace()),
true,
errors::PreconditionNotMet("PutAlongAxisOpKernel only runs on CPU."));
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
const auto& index_type =
paddle::framework::TransToProtoVarType(index.dtype());
if (reduce == "add") {
if (index_type == paddle::framework::proto::VarType::INT32) {
paddle::operators::cpu_scatter_add_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx);
} else if (index_type == paddle::framework::proto::VarType::INT64) {
paddle::operators::cpu_scatter_add_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx);
}
} else if (reduce == "multiply" || reduce == "mul") {
if (index_type == paddle::framework::proto::VarType::INT32) {
paddle::operators::cpu_scatter_mul_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx);
} else if (index_type == paddle::framework::proto::VarType::INT64) {
paddle::operators::cpu_scatter_mul_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx);
}
} else if (reduce == "assign") {
if (index_type == paddle::framework::proto::VarType::INT32) {
paddle::operators::cpu_scatter_assign_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx);
} else if (index_type == paddle::framework::proto::VarType::INT64) {
paddle::operators::cpu_scatter_assign_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx);
}
} else {
PADDLE_THROW(errors::InvalidArgument(
"can not support reduce: '%s' for scatter kernel, only "
"support reduce op: 'add', 'assign', 'mul' and 'multiply', the "
"defalut reduce "
"op is 'assign' ",
reduce));
return;
}
}
} // namespace phi
PD_REGISTER_KERNEL(put_along_axis,
CPU,
ALL_LAYOUT,
phi::PutAlongAxisKernel,
float,
double,
int,
uint8_t,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/take_along_axis_grad_kernel.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T, typename Context>
void TakeAlongAxisGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& index,
const DenseTensor& out_grad,
int axis,
DenseTensor* x_grad) {
PADDLE_ENFORCE_EQ(
paddle::platform::is_cpu_place(dev_ctx.GetPlace()),
true,
errors::PreconditionNotMet("This kernel only runs on CPU."));
// We need to know the shape of input matrix to determine the shape of grad
// matrix of input.
x_grad->Resize(x.dims());
dev_ctx.template Alloc<T>(x_grad);
// Set to zero tensor.
phi::funcs::SetConstant<Context, T> functor;
functor(dev_ctx, x_grad, static_cast<T>(0));
const auto& index_type =
paddle::framework::TransToProtoVarType(index.dtype());
if (index_type == paddle::framework::proto::VarType::INT32) {
paddle::operators::cpu_scatter_add_kernel<T, int32_t>(
*x_grad,
axis,
index,
out_grad,
dev_ctx); // the gradient of gather is scatter
} else if (index_type == paddle::framework::proto::VarType::INT64) {
paddle::operators::cpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, index, out_grad, dev_ctx);
}
}
} // namespace phi
PD_REGISTER_KERNEL(take_along_axis_grad,
CPU,
ALL_LAYOUT,
phi::TakeAlongAxisGradKernel,
float,
double,
int,
uint8_t,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/take_along_axis_kernel.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void TakeAlongAxisKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& index,
int axis,
DenseTensor* out) {
PADDLE_ENFORCE_EQ(
paddle::platform::is_cpu_place(dev_ctx.GetPlace()),
true,
errors::PreconditionNotMet("This kernel only runs on CPU."));
out->Resize(index.dims());
dev_ctx.template Alloc<T>(out);
const auto& index_type =
paddle::framework::TransToProtoVarType(index.dtype());
if (index_type == paddle::framework::proto::VarType::INT32) {
paddle::operators::cpu_gather_kernel<T, int32_t>(
x, axis, index, *out, dev_ctx);
} else if (index_type == paddle::framework::proto::VarType::INT64) {
paddle::operators::cpu_gather_kernel<T, int64_t>(
x, axis, index, *out, dev_ctx);
}
}
} // namespace phi
PD_REGISTER_KERNEL(take_along_axis,
CPU,
ALL_LAYOUT,
phi::TakeAlongAxisKernel,
float,
double,
int,
uint8_t,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/impl/maxout_grad_kernel_impl.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(
maxout_grad, GPU, ALL_LAYOUT, phi::MaxOutGradKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/impl/maxout_kernel_impl.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(maxout, GPU, ALL_LAYOUT, phi::MaxOutKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/put_along_axis_grad_kernel.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
namespace phi {
template <typename T, typename Context>
void PutAlongAxisGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& index,
const DenseTensor& out_grad,
int axis,
const std::string& reduce,
DenseTensor* x_grad,
DenseTensor* value_grad) {
PADDLE_ENFORCE_EQ(paddle::platform::is_gpu_place(dev_ctx.GetPlace()),
true,
errors::PreconditionNotMet(
"PutAlongAxisGradOpCUDAKernel only runs on GPU."));
const auto& index_type =
paddle::framework::TransToProtoVarType(index.dtype());
if (x_grad) {
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
if (index_type == paddle::framework::proto::VarType::INT32) {
paddle::operators::gpu_scatter_input_grad_kernel<T, int32_t>(
out_grad, axis, index, *x_grad, dev_ctx);
} else {
paddle::operators::gpu_scatter_input_grad_kernel<T, int64_t>(
out_grad, axis, index, *x_grad, dev_ctx);
}
}
if (value_grad) {
value_grad->Resize(index.dims());
value_grad->mutable_data<T>(dev_ctx.GetPlace());
if (index_type == paddle::framework::proto::VarType::INT32) {
paddle::operators::gpu_gather_kernel<T, int32_t>(
out_grad,
axis,
index,
*value_grad,
dev_ctx); // the gradient of scatter is gather
} else if (index_type == paddle::framework::proto::VarType::INT64) {
paddle::operators::gpu_gather_kernel<T, int64_t>(
out_grad, axis, index, *value_grad, dev_ctx);
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(put_along_axis_grad,
GPU,
ALL_LAYOUT,
phi::PutAlongAxisGradKernel,
float,
double,
int64_t,
int,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/put_along_axis_kernel.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
namespace phi {
template <typename T, typename Context>
void PutAlongAxisKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& index,
const DenseTensor& value,
int axis,
const std::string& reduce,
DenseTensor* out) {
PADDLE_ENFORCE_EQ(paddle::platform::is_gpu_place(dev_ctx.GetPlace()),
true,
errors::PreconditionNotMet(
"PutAlongAxisCUDAKernel only runs on GPU device."));
const auto& index_type =
paddle::framework::TransToProtoVarType(index.dtype());
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
if (reduce == "add") {
if (index_type == paddle::framework::proto::VarType::INT32) {
paddle::operators::gpu_scatter_add_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx);
} else if (index_type == paddle::framework::proto::VarType::INT64) {
paddle::operators::gpu_scatter_add_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx);
}
} else if (reduce == "multiply" || reduce == "mul") {
if (index_type == paddle::framework::proto::VarType::INT32) {
paddle::operators::gpu_scatter_mul_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx);
} else if (index_type == paddle::framework::proto::VarType::INT64) {
paddle::operators::gpu_scatter_mul_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx);
}
} else if (reduce == "assign") {
if (index_type == paddle::framework::proto::VarType::INT32) {
paddle::operators::gpu_scatter_assign_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx);
} else if (index_type == paddle::framework::proto::VarType::INT64) {
paddle::operators::gpu_scatter_assign_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx);
}
} else {
PADDLE_THROW(errors::InvalidArgument(
"can not support reduce: '%s' for scatter kernel, only "
"support reduce op: 'add', 'assign', 'mul' and 'multiply', the "
"defalut reduce op is 'assign' ",
reduce));
return;
}
}
} // namespace phi
PD_REGISTER_KERNEL(put_along_axis,
GPU,
ALL_LAYOUT,
phi::PutAlongAxisKernel,
float,
double,
int64_t,
int,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/take_along_axis_grad_kernel.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T, typename Context>
void TakeAlongAxisGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& index,
const DenseTensor& out_grad,
int axis,
DenseTensor* x_grad) {
PADDLE_ENFORCE_EQ(
paddle::platform::is_gpu_place(dev_ctx.GetPlace()),
true,
errors::PreconditionNotMet("This kernel only runs on GPU."));
// We need to know the shape of input matrix to determine the shape of grad
// matrix of input.
x_grad->Resize(x.dims());
dev_ctx.template Alloc<T>(x_grad);
// Set to zero tensor.
phi::funcs::SetConstant<Context, T> functor;
functor(dev_ctx, x_grad, static_cast<T>(0));
const auto& index_type =
paddle::framework::TransToProtoVarType(index.dtype());
if (index_type == paddle::framework::proto::VarType::INT32) {
paddle::operators::gpu_scatter_add_kernel<T, int32_t>(
*x_grad,
axis,
index,
out_grad,
dev_ctx); // the gradient of gather is scatter
} else if (index_type == paddle::framework::proto::VarType::INT64) {
paddle::operators::gpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, index, out_grad, dev_ctx);
}
}
} // namespace phi
PD_REGISTER_KERNEL(take_along_axis_grad,
GPU,
ALL_LAYOUT,
phi::TakeAlongAxisGradKernel,
float,
double,
int64_t,
int,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/take_along_axis_kernel.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void TakeAlongAxisKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& index,
int axis,
DenseTensor* out) {
PADDLE_ENFORCE_EQ(
paddle::platform::is_gpu_place(dev_ctx.GetPlace()),
true,
errors::PreconditionNotMet("This kernel only runs on GPU device."));
out->Resize(index.dims());
dev_ctx.template Alloc<T>(out);
const auto& index_type =
paddle::framework::TransToProtoVarType(index.dtype());
if (index_type == paddle::framework::proto::VarType::INT32) {
paddle::operators::gpu_gather_kernel<T, int32_t>(
x, axis, index, *out, dev_ctx);
} else if (index_type == paddle::framework::proto::VarType::INT64) {
paddle::operators::gpu_gather_kernel<T, int64_t>(
x, axis, index, *out, dev_ctx);
}
}
} // namespace phi
PD_REGISTER_KERNEL(take_along_axis,
GPU,
ALL_LAYOUT,
phi::TakeAlongAxisKernel,
double,
int64_t,
int,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/maxout_grad_kernel.h"
#include "paddle/fluid/operators/math/maxouting.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T, typename Context>
void MaxOutGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& out_grad,
int groups,
int axis,
DenseTensor* x_grad) {
if (axis < 0) {
axis += x.dims().size();
}
phi::funcs::SetConstant<Context, T> zero;
if (x_grad) {
dev_ctx.template Alloc<T>(x_grad);
zero(dev_ctx, x_grad, static_cast<T>(0.0));
paddle::operators::math::MaxOutGradFunctor<Context, T> maxout_backward;
maxout_backward(dev_ctx, x, x_grad, out, out_grad, groups, axis);
}
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/maxout_kernel.h"
#include "paddle/fluid/operators/math/maxouting.h"
namespace phi {
template <typename T, typename Context>
void MaxOutKernel(const Context& dev_ctx,
const DenseTensor& x,
int groups,
int axis,
DenseTensor* out) {
if (axis < 0) {
axis += x.dims().size();
}
paddle::operators::math::MaxOutFunctor<Context, T> maxout_forward;
maxout_forward(dev_ctx, x, out, groups, axis);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void MaxOutGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& out_grad,
int groups,
int axis,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void MaxOutKernel(const Context& dev_ctx,
const DenseTensor& x,
int groups,
int axis,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void PutAlongAxisGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& index,
const DenseTensor& out_grad,
int axis,
const std::string& reduce,
DenseTensor* x_grad,
DenseTensor* value_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void PutAlongAxisKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& index,
const DenseTensor& value,
int axis,
const std::string& reduce,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void TakeAlongAxisGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& index,
const DenseTensor& out_grad,
int axis,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void TakeAlongAxisKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& index,
int axis,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature MaxoutArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("maxout", {"X"}, {"groups", "axis"}, {"Out"});
}
KernelSignature MaxoutGradArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("maxout_grad",
{"X", "Out", GradVarName("Out")},
{"groups", "axis"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(maxout, phi::MaxoutArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(maxout_grad, phi::MaxoutGradArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature PutAlongAxisArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("put_along_axis",
{"Input", "Index", "Value"},
{"Axis", "Reduce"},
{"Result"});
}
KernelSignature PutAlongAxisGradArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("put_along_axis_grad",
{"Input", "Index", GradVarName("Result")},
{"Axis", "Reduce"},
{GradVarName("Input"), GradVarName("Value")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(put_along_axis, phi::PutAlongAxisArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(put_along_axis_grad,
phi::PutAlongAxisGradArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature TakeAlongAxisArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"take_along_axis", {"Input", "Index"}, {"Axis"}, {"Result"});
}
KernelSignature TakeAlongAxisGradArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("take_along_axis_grad",
{"Input", "Index", GradVarName("Result")},
{"Axis"},
{GradVarName("Input")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(take_along_axis, phi::TakeAlongAxisArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(take_along_axis_grad,
phi::TakeAlongAxisGradArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册