未验证 提交 fe911a51 编写于 作者: C Chenxiao Niu 提交者: GitHub

add concat_grad mlu kernel. (#43117)

上级 d999049f
......@@ -74,6 +74,64 @@ class ConcatMLUKernel : public framework::OpKernel<T> {
output_desc.get(), GetBasePtr(out));
}
};
// Gradient kernel for concat on MLU: splits the incoming output gradient
// back into one gradient tensor per forward input along the concat axis,
// via the CNNL Split op.
template <typename T>
class ConcatGradMLUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* out_grad =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto ins = ctx.MultiInput<framework::LoDTensor>("X");
    auto out_var_names = ctx.OutputNames(framework::GradVarName("X"));
    auto outs =
        ctx.MultiOutput<framework::LoDTensor>(framework::GradVarName("X"));
    auto axis = ctx.Attr<int>("axis");
    // One split per forward input.
    int split_num = ins.size();
    PADDLE_ENFORCE_NOT_NULL(ins[0],
                            platform::errors::NotFound(
                                "The first input tensor is not initialized."));
    // A runtime AxisTensor, when present, overrides the static attribute.
    if (ctx.HasInput("AxisTensor")) {
      auto* axis_tensor = ctx.Input<framework::Tensor>("AxisTensor");
      axis = GetDataFromTensor<int>(axis_tensor)[0];
    }
    // Normalize a possibly-negative axis into [0, rank).
    axis = ComputeAxis(static_cast<int64_t>(axis),
                       static_cast<int64_t>(ins[0]->dims().size()));
    PADDLE_ENFORCE_GE(axis, 0, platform::errors::InvalidArgument(
                                   "concat_grad: axis should be larger than or "
                                   "equal to 0, but received axis is %d.",
                                   axis));
    PADDLE_ENFORCE_LT(
        axis, out_grad->dims().size(),
        platform::errors::InvalidArgument(
            "concat_grad: axis should be less than out_grad->dims()!"
            "But received axis is %d, while out_grad->dims()"
            "size is %d.",
            axis, out_grad->dims().size()));
    // Collect only the outputs whose gradient is actually requested
    // (variable name is not kEmptyVarName) and non-empty.
    std::vector<void*> outputs_vec;
    std::vector<MLUCnnlTensorDesc> output_descs;
    std::vector<cnnlTensorDescriptor_t> descs_vec;
    for (size_t j = 0; j < outs.size(); ++j) {
      if (out_var_names[j] != framework::kEmptyVarName &&
          outs[j]->numel() != 0UL) {
        outs[j]->mutable_data<T>(ctx.GetPlace());
        output_descs.emplace_back(MLUCnnlTensorDesc(*outs[j]));
        descs_vec.push_back(output_descs.back().get());
        outputs_vec.push_back(GetBasePtr(outs[j]));
      } else {
        // NOTE(review): a nullptr placeholder goes into outputs_vec but no
        // matching descriptor goes into descs_vec, so the two vectors can
        // diverge in length when a grad slot is skipped — confirm that
        // MLUCnnl::Split/cnnlSplit tolerates this layout.
        outputs_vec.push_back(nullptr);
      }
    }
    MLUCnnlTensorDesc out_grad_desc(*out_grad);
    // Split out_grad into split_num pieces along `axis`, writing each piece
    // directly into the corresponding input-gradient buffer.
    MLUCnnl::Split(ctx, static_cast<int>(split_num), static_cast<int>(axis),
                   out_grad_desc.get(), GetBasePtr(out_grad), descs_vec.data(),
                   outputs_vec.data());
  }
};
} // namespace operators
} // namespace paddle
......@@ -84,3 +142,9 @@ REGISTER_OP_MLU_KERNEL(concat, ops::ConcatMLUKernel<float>,
ops::ConcatMLUKernel<int64_t>,
ops::ConcatMLUKernel<bool>, ops::ConcatMLUKernel<int>,
ops::ConcatMLUKernel<uint8_t>);
// Register the concat_grad MLU kernel for the same dtypes as the forward
// concat kernel, plus float16.
REGISTER_OP_MLU_KERNEL(concat_grad, ops::ConcatGradMLUKernel<float>,
                       ops::ConcatGradMLUKernel<paddle::platform::float16>,
                       ops::ConcatGradMLUKernel<int64_t>,
                       ops::ConcatGradMLUKernel<bool>,
                       ops::ConcatGradMLUKernel<int>,
                       ops::ConcatGradMLUKernel<uint8_t>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册