提交 8b1048b4 编写于 作者: L Leo Chen 提交者: GitHub

Revert "[pten] remove concat fluid kernel (#39268)"

This reverts commit 552db8dc.
上级 a909bdf1
......@@ -244,7 +244,19 @@ REGISTER_OPERATOR(concat_grad, ops::ConcatOpGrad,
ops::ConcatDoubleGradOpMaker<paddle::framework::OpDesc>,
ops::ConcatDoubleGradOpMaker<paddle::imperative::OpBase>,
ops::ConcatOpGradNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(
concat, ops::ConcatKernel<paddle::platform::CPUDeviceContext, double>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, float>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, bool>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, int>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
concat_grad,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, double>,
......
......@@ -19,7 +19,18 @@ limitations under the License. */
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
concat, ops::ConcatKernel<paddle::platform::CUDADeviceContext, double>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, float>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, bool>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, int>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext,
plat::complex<float>>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext,
plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
concat_grad,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, double>,
......
......@@ -39,6 +39,54 @@ static inline int64_t ComputeAxis(int64_t axis, int64_t rank) {
}
return axis > 0 ? axis : 0;
}
template <typename DeviceContext, typename T>
class ConcatKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<framework::LoDTensor>("X");
framework::LoDTensor* out = ctx.Output<framework::LoDTensor>("Out");
PADDLE_ENFORCE_NOT_NULL(ins[0],
platform::errors::NotFound(
"The first input tensor is not initalized."));
auto axis = ctx.Attr<int>("axis");
bool need_resize_out_dims = false;
if (ctx.HasInput("AxisTensor")) {
auto* axis_tensor = ctx.Input<framework::Tensor>("AxisTensor");
axis = GetDataFromTensor<int>(axis_tensor)[0];
need_resize_out_dims = true;
}
axis = ComputeAxis(static_cast<int64_t>(axis),
static_cast<int64_t>(ins[0]->dims().size()));
if (need_resize_out_dims) {
const size_t n = ins.size();
std::vector<framework::DDim> ins_dims(n);
for (size_t i = 0; i < n; i++) {
ins_dims[i] = ins[i]->dims();
}
framework::DDim out_dims =
pten::funcs::ComputeAndCheckShape(true, ins_dims, axis);
out->Resize(out_dims);
}
auto place = ctx.GetPlace();
out->mutable_data<T>(place);
// call new kernel
auto& dev_ctx = ctx.device_context<DeviceContext>();
std::vector<pten::DenseTensor> pt_ins;
for (auto& in : ins) {
pt_ins.push_back(*in);
}
pten::ConcatKernel<T>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
pt_ins, axis, out);
}
};
template <typename DeviceContext, typename T>
class ConcatGradKernel : public framework::OpKernel<T> {
public:
......
......@@ -299,7 +299,7 @@ class TensorArrayToTensorGradOpMaker : public framework::SingleGradOpMaker<T> {
} // namespace operators
} // namespace paddle
USE_OP_ITSELF(concat);
USE_OP(concat);
namespace ops = paddle::operators;
REGISTER_OPERATOR(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册