未验证 提交 d480d7b1 编写于 作者: 0 0x45f 提交者: GitHub

Move lerp OP to pten (#39524)

* move lerp to pten

* refine include

* move files

* refine code
上级 ff7e3590
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/lerp_op.h" #include "paddle/fluid/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -132,15 +132,3 @@ REGISTER_OPERATOR( ...@@ -132,15 +132,3 @@ REGISTER_OPERATOR(
paddle::operators::LerpInplaceInferer); paddle::operators::LerpInplaceInferer);
REGISTER_OPERATOR(lerp_grad, paddle::operators::LerpGradOp); REGISTER_OPERATOR(lerp_grad, paddle::operators::LerpGradOp);
REGISTER_OP_CPU_KERNEL(
lerp,
paddle::operators::LerpKernel<paddle::platform::CPUDeviceContext, float>,
paddle::operators::LerpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
lerp_grad,
paddle::operators::LerpGradKernel<paddle::platform::CPUDeviceContext,
float>,
paddle::operators::LerpGradKernel<paddle::platform::CPUDeviceContext,
double>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#ifdef _WIN32
#ifndef NOMINMAX
#define NOMINMAX // msvc max/min macro conflict with std::min/max
#endif
#endif
namespace paddle {
namespace operators {
static framework::DDim ExtendDims2Rank(const framework::DDim& in_dims,
int rank) {
if (in_dims.size() == rank) {
return in_dims;
}
std::vector<int64_t> shapes(rank, 1);
for (int i = in_dims.size() - 1, j = rank - 1; i >= 0; --i, --j) {
shapes[j] = in_dims[i];
}
return framework::make_ddim(shapes);
}
template <size_t D>
static void GetBroadcastDims(const framework::DDim& in_dims,
const framework::DDim& out_dims,
Eigen::DSizes<int, D>* bcast_dims) {
for (size_t i = 0; i < D; ++i) {
if (in_dims[i] == out_dims[i]) {
(*bcast_dims)[i] = 1;
} else {
(*bcast_dims)[i] = std::max(in_dims[i], out_dims[i]);
}
}
}
template <typename DeviceContext, typename T, size_t D>
static void LerpFunction(const framework::ExecutionContext& ctx) {
auto x = ctx.Input<framework::Tensor>("X");
auto y = ctx.Input<framework::Tensor>("Y");
auto w = ctx.Input<framework::Tensor>("Weight");
auto out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
auto out_dims = out->dims();
auto x_dims = ExtendDims2Rank(x->dims(), D);
auto y_dims = ExtendDims2Rank(y->dims(), D);
auto w_dims = ExtendDims2Rank(w->dims(), D);
Eigen::DSizes<int, D> x_bcast_dims;
Eigen::DSizes<int, D> y_bcast_dims;
Eigen::DSizes<int, D> w_bcast_dims;
GetBroadcastDims<D>(x_dims, out_dims, &x_bcast_dims);
GetBroadcastDims<D>(y_dims, out_dims, &y_bcast_dims);
GetBroadcastDims<D>(w_dims, out_dims, &w_bcast_dims);
auto eigen_x = framework::EigenTensor<T, D>::From(*x, x_dims);
auto eigen_y = framework::EigenTensor<T, D>::From(*y, y_dims);
auto eigen_w = framework::EigenTensor<T, D>::From(*w, w_dims);
auto eigen_out = framework::EigenTensor<T, D>::From(*out);
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
eigen_out.device(place) =
eigen_x.broadcast(x_bcast_dims) +
eigen_w.broadcast(w_bcast_dims) *
(eigen_y.broadcast(y_bcast_dims) - eigen_x.broadcast(x_bcast_dims));
}
template <typename DeviceContext, typename T, size_t D>
static void LerpGradFunction(const framework::ExecutionContext& ctx) {
auto w = ctx.Input<framework::Tensor>("Weight");
auto dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
auto dout_dims = dout->dims();
auto dx_dims = ExtendDims2Rank(dx->dims(), D);
auto dy_dims = ExtendDims2Rank(dy->dims(), D);
auto w_dims = ExtendDims2Rank(w->dims(), D);
Eigen::DSizes<int, D> dx_bcast_dims;
Eigen::DSizes<int, D> dy_bcast_dims;
Eigen::DSizes<int, D> w_bcast_dims;
GetBroadcastDims<D>(dx_dims, dout_dims, &dx_bcast_dims);
GetBroadcastDims<D>(dy_dims, dout_dims, &dy_bcast_dims);
GetBroadcastDims<D>(w_dims, dout_dims, &w_bcast_dims);
auto eigen_w = framework::EigenTensor<T, D>::From(*w, w_dims);
auto eigen_dout = framework::EigenTensor<T, D>::From(*dout);
Eigen::DSizes<int, D * 2> dx_reshape_dims;
Eigen::DSizes<int, D * 2> dy_reshape_dims;
Eigen::DSizes<int, D> reduce_dims;
for (int i = 0; i < dout_dims.size(); ++i) {
dx_reshape_dims[2 * i] = dx_bcast_dims[i];
dx_reshape_dims[2 * i + 1] = dx_dims[i];
dy_reshape_dims[2 * i] = dy_bcast_dims[i];
dy_reshape_dims[2 * i + 1] = dy_dims[i];
reduce_dims[i] = 2 * i;
}
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
auto eigen_dx = framework::EigenTensor<T, D>::From(*dx, dx_dims);
auto eigen_expr = (1 - eigen_w.broadcast(w_bcast_dims)) * eigen_dout;
eigen_dx.device(place) = eigen_expr.reshape(dx_reshape_dims)
.sum(reduce_dims)
.reshape(eigen_dx.dimensions());
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
auto eigen_dy = framework::EigenTensor<T, D>::From(*dy, dy_dims);
auto eigen_expr = eigen_w.broadcast(w_bcast_dims) * eigen_dout;
eigen_dy.device(place) = eigen_expr.reshape(dy_reshape_dims)
.sum(reduce_dims)
.reshape(eigen_dy.dimensions());
}
}
template <typename DeviceContext, typename T>
class LerpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
int rank = ctx.Output<framework::Tensor>("Out")->dims().size();
PADDLE_ENFORCE_GE(
rank, 1,
platform::errors::InvalidArgument(
"The number of dimensions for LerpOp must be "
"greater than or equal to 1, but the value received is %d.",
rank));
PADDLE_ENFORCE_LE(
rank, 6, platform::errors::InvalidArgument(
"The number of dimensions for LerpOp must be "
"less than or equal to 6, but the value received is %d.",
rank));
switch (rank) {
case 1:
LerpFunction<DeviceContext, T, 1>(ctx);
break;
case 2:
LerpFunction<DeviceContext, T, 2>(ctx);
break;
case 3:
LerpFunction<DeviceContext, T, 3>(ctx);
break;
case 4:
LerpFunction<DeviceContext, T, 4>(ctx);
break;
case 5:
LerpFunction<DeviceContext, T, 5>(ctx);
break;
case 6:
LerpFunction<DeviceContext, T, 6>(ctx);
break;
}
}
};
template <typename DeviceContext, typename T>
class LerpGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
int rank = ctx.Input<framework::Tensor>(framework::GradVarName("Out"))
->dims()
.size();
PADDLE_ENFORCE_GE(
rank, 1,
platform::errors::InvalidArgument(
"The number of dimensions for LerpGradOp must be "
"greater than or equal to 1, but the value received is %d.",
rank));
PADDLE_ENFORCE_LE(
rank, 6, platform::errors::InvalidArgument(
"The number of dimensions for LerpGradOp must be "
"less than or equal to 6, but the value received is %d.",
rank));
switch (rank) {
case 1:
LerpGradFunction<DeviceContext, T, 1>(ctx);
break;
case 2:
LerpGradFunction<DeviceContext, T, 2>(ctx);
break;
case 3:
LerpGradFunction<DeviceContext, T, 3>(ctx);
break;
case 4:
LerpGradFunction<DeviceContext, T, 4>(ctx);
break;
case 5:
LerpGradFunction<DeviceContext, T, 5>(ctx);
break;
case 6:
LerpGradFunction<DeviceContext, T, 6>(ctx);
break;
}
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/lerp_grad_kernel.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/lerp_grad_kernel_impl.h"
PT_REGISTER_KERNEL(
lerp_grad, CPU, ALL_LAYOUT, pten::LerpGradKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/lerp_kernel.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/lerp_kernel_impl.h"
PT_REGISTER_KERNEL(lerp, CPU, ALL_LAYOUT, pten::LerpKernel, float, double) {}
...@@ -102,5 +102,30 @@ inline void GetPrePostNumel( ...@@ -102,5 +102,30 @@ inline void GetPrePostNumel(
} }
} }
static framework::DDim ExtendDims2Rank(const framework::DDim &in_dims,
int rank) {
if (in_dims.size() == rank) {
return in_dims;
}
std::vector<int64_t> shapes(rank, 1);
for (int i = in_dims.size() - 1, j = rank - 1; i >= 0; --i, --j) {
shapes[j] = in_dims[i];
}
return framework::make_ddim(shapes);
}
template <size_t D>
static void GetBroadcastDims(const framework::DDim &in_dims,
const framework::DDim &out_dims,
Eigen::DSizes<int, D> *bcast_dims) {
for (size_t i = 0; i < D; ++i) {
if (in_dims[i] == out_dims[i]) {
(*bcast_dims)[i] = 1;
} else {
(*bcast_dims)[i] = std::max(in_dims[i], out_dims[i]);
}
}
}
} // namespace funcs } // namespace funcs
} // namespace pten } // namespace pten
...@@ -12,16 +12,10 @@ ...@@ -12,16 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/lerp_op.h" #include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/lerp_grad_kernel_impl.h"
#include "paddle/pten/kernels/lerp_grad_kernel.h"
REGISTER_OP_CUDA_KERNEL( PT_REGISTER_KERNEL(
lerp, lerp_grad, GPU, ALL_LAYOUT, pten::LerpGradKernel, float, double) {}
paddle::operators::LerpKernel<paddle::platform::CUDADeviceContext, float>,
paddle::operators::LerpKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
lerp_grad,
paddle::operators::LerpGradKernel<paddle::platform::CUDADeviceContext,
float>,
paddle::operators::LerpGradKernel<paddle::platform::CUDADeviceContext,
double>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/lerp_kernel_impl.h"
#include "paddle/pten/kernels/lerp_kernel.h"
PT_REGISTER_KERNEL(lerp, GPU, ALL_LAYOUT, pten::LerpKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/kernels/funcs/common_shape.h"
#include "paddle/pten/kernels/funcs/eigen/common.h"
namespace pten {
template <typename Context, typename T, size_t D>
static void LerpGradFunction(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& weight,
const DenseTensor& out,
const DenseTensor& out_grad,
DenseTensor* x_grad,
DenseTensor* y_grad) {
auto& w = weight;
auto& dout = out_grad;
auto* dx = x_grad;
auto* dy = y_grad;
auto dout_dims = dout.dims();
auto dx_dims = pten::funcs::ExtendDims2Rank(dx->dims(), D);
auto dy_dims = pten::funcs::ExtendDims2Rank(dy->dims(), D);
auto w_dims = pten::funcs::ExtendDims2Rank(w.dims(), D);
Eigen::DSizes<int, D> dx_bcast_dims;
Eigen::DSizes<int, D> dy_bcast_dims;
Eigen::DSizes<int, D> w_bcast_dims;
pten::funcs::GetBroadcastDims<D>(dx_dims, dout_dims, &dx_bcast_dims);
pten::funcs::GetBroadcastDims<D>(dy_dims, dout_dims, &dy_bcast_dims);
pten::funcs::GetBroadcastDims<D>(w_dims, dout_dims, &w_bcast_dims);
auto eigen_w = pten::EigenTensor<T, D>::From(w, w_dims);
auto eigen_dout = pten::EigenTensor<T, D>::From(dout);
Eigen::DSizes<int, D * 2> dx_reshape_dims;
Eigen::DSizes<int, D * 2> dy_reshape_dims;
Eigen::DSizes<int, D> reduce_dims;
for (int i = 0; i < dout_dims.size(); ++i) {
dx_reshape_dims[2 * i] = dx_bcast_dims[i];
dx_reshape_dims[2 * i + 1] = dx_dims[i];
dy_reshape_dims[2 * i] = dy_bcast_dims[i];
dy_reshape_dims[2 * i + 1] = dy_dims[i];
reduce_dims[i] = 2 * i;
}
auto& place = *ctx.eigen_device();
if (dx) {
ctx.template Alloc<T>(dx);
auto eigen_dx = pten::EigenTensor<T, D>::From(*dx, dx_dims);
auto eigen_expr = (1 - eigen_w.broadcast(w_bcast_dims)) * eigen_dout;
eigen_dx.device(place) = eigen_expr.reshape(dx_reshape_dims)
.sum(reduce_dims)
.reshape(eigen_dx.dimensions());
}
if (dy) {
ctx.template Alloc<T>(dy);
auto eigen_dy = pten::EigenTensor<T, D>::From(*dy, dy_dims);
auto eigen_expr = eigen_w.broadcast(w_bcast_dims) * eigen_dout;
eigen_dy.device(place) = eigen_expr.reshape(dy_reshape_dims)
.sum(reduce_dims)
.reshape(eigen_dy.dimensions());
}
}
template <typename T, typename Context>
void LerpGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& weight,
const DenseTensor& out,
const DenseTensor& out_grad,
DenseTensor* x_grad,
DenseTensor* y_grad) {
int rank = out.dims().size();
PADDLE_ENFORCE_GE(
rank,
1,
pten::errors::InvalidArgument(
"The number of dimensions for LerpGradOp must be "
"greater than or equal to 1, but the value received is %d.",
rank));
PADDLE_ENFORCE_LE(
rank,
6,
pten::errors::InvalidArgument(
"The number of dimensions for LerpGradOp must be "
"less than or equal to 6, but the value received is %d.",
rank));
switch (rank) {
case 1:
LerpGradFunction<Context, T, 1>(
ctx, x, y, weight, out, out_grad, x_grad, y_grad);
break;
case 2:
LerpGradFunction<Context, T, 2>(
ctx, x, y, weight, out, out_grad, x_grad, y_grad);
break;
case 3:
LerpGradFunction<Context, T, 3>(
ctx, x, y, weight, out, out_grad, x_grad, y_grad);
break;
case 4:
LerpGradFunction<Context, T, 4>(
ctx, x, y, weight, out, out_grad, x_grad, y_grad);
break;
case 5:
LerpGradFunction<Context, T, 5>(
ctx, x, y, weight, out, out_grad, x_grad, y_grad);
break;
case 6:
LerpGradFunction<Context, T, 6>(
ctx, x, y, weight, out, out_grad, x_grad, y_grad);
break;
}
}
} // namespace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/kernels/funcs/common_shape.h"
#include "paddle/pten/kernels/funcs/eigen/common.h"
namespace pten {
template <typename Context, typename T, size_t D>
static void LerpFunction(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& weight,
DenseTensor* out) {
ctx.template Alloc<T>(out);
auto out_dims = out->dims();
auto x_dims = pten::funcs::ExtendDims2Rank(x.dims(), D);
auto y_dims = pten::funcs::ExtendDims2Rank(y.dims(), D);
auto w_dims = pten::funcs::ExtendDims2Rank(weight.dims(), D);
Eigen::DSizes<int, D> x_bcast_dims;
Eigen::DSizes<int, D> y_bcast_dims;
Eigen::DSizes<int, D> w_bcast_dims;
pten::funcs::GetBroadcastDims<D>(x_dims, out_dims, &x_bcast_dims);
pten::funcs::GetBroadcastDims<D>(y_dims, out_dims, &y_bcast_dims);
pten::funcs::GetBroadcastDims<D>(w_dims, out_dims, &w_bcast_dims);
auto eigen_x = pten::EigenTensor<T, D>::From(x, x_dims);
auto eigen_y = pten::EigenTensor<T, D>::From(y, y_dims);
auto eigen_w = pten::EigenTensor<T, D>::From(weight, w_dims);
auto eigen_out = pten::EigenTensor<T, D>::From(*out);
auto& place = *ctx.eigen_device();
eigen_out.device(place) =
eigen_x.broadcast(x_bcast_dims) +
eigen_w.broadcast(w_bcast_dims) *
(eigen_y.broadcast(y_bcast_dims) - eigen_x.broadcast(x_bcast_dims));
}
template <typename T, typename Context>
void LerpKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& weight,
DenseTensor* out) {
int rank = out->dims().size();
PADDLE_ENFORCE_GE(
rank,
1,
pten::errors::InvalidArgument(
"The number of dimensions for LerpOp must be "
"greater than or equal to 1, but the value received is %d.",
rank));
PADDLE_ENFORCE_LE(
rank,
6,
pten::errors::InvalidArgument(
"The number of dimensions for LerpOp must be "
"less than or equal to 6, but the value received is %d.",
rank));
switch (rank) {
case 1:
LerpFunction<Context, T, 1>(ctx, x, y, weight, out);
break;
case 2:
LerpFunction<Context, T, 2>(ctx, x, y, weight, out);
break;
case 3:
LerpFunction<Context, T, 3>(ctx, x, y, weight, out);
break;
case 4:
LerpFunction<Context, T, 4>(ctx, x, y, weight, out);
break;
case 5:
LerpFunction<Context, T, 5>(ctx, x, y, weight, out);
break;
case 6:
LerpFunction<Context, T, 6>(ctx, x, y, weight, out);
break;
}
}
} // namespace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/core/dense_tensor.h"
namespace pten {
template <typename T, typename Context>
void LerpGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& weight,
const DenseTensor& out,
const DenseTensor& out_grad,
DenseTensor* x_grad,
DenseTensor* y_grad);
} // namespace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/core/dense_tensor.h"
namespace pten {
template <typename T, typename Context>
void LerpKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& weight,
DenseTensor* out);
} // namespace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/core/compat/op_utils.h"
namespace pten {
KernelSignature LerpOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("lerp", {"X", "Y", "Weight"}, {}, {"Out"});
}
KernelSignature LerpGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("lerp_grad",
{"X", "Y", "Weight", "Out", GradVarName("Out")},
{},
{GradVarName("X"), GradVarName("Y")});
}
} // namespace pten
PT_REGISTER_ARG_MAPPING_FN(lerp, pten::LerpOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(lerp_grad, pten::LerpGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册