未验证 提交 8497e2aa 编写于 作者: L Leo Chen 提交者: GitHub

[NPU] add npu kernel for elementwise_add_grad (#31347)

* fix reading flags from env

* fix problem caused by async run

* support partial grad

* support elementwise_add_grad npu kernel

* add unittest

* fix bug?
上级 9fcdaeba
...@@ -12,17 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,17 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef PADDLE_WITH_ASCEND_CL
#include <memory> #include <memory>
#include <string> #include <string>
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h" #include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/npu_op_runner.h" #include "paddle/fluid/operators/npu_op_runner.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T> template <typename T>
class ElementwiseAddNPUKernel : public framework::OpKernel<T> { class ElementwiseAddNPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -39,12 +40,127 @@ class ElementwiseAddNPUKernel : public framework::OpKernel<T> { ...@@ -39,12 +40,127 @@ class ElementwiseAddNPUKernel : public framework::OpKernel<T> {
} }
}; };
template <typename T>
class ElementwiseAddGradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
// NOTE(zhiqiu): It seems Ascend Sub follow the broadcast sematics with
// default axis=-1?
// So, the sub_grad should do reduce if needed.
// For example, the shape of each variable in elementwise_sub:
// x, dx: [2, 3, 5]
// y, dy: [1, 5]
// out, dout: [2, 3, 5]
// Then, out = x - y => dx = dout, dy = -dout
// And, the shape of dy can be computed by two stages reduce,
// 1. [2, 3, 5] => [3, 5], ReduceSumD on axis = 0, keep_dims = false.
// 2. [3, 5] => [1, 5], ReduceSumD on axis = 0, keep_dims = true.
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
// For dx
// stage 1
auto reduce_ndim = dout->dims().size() - dx->dims().size();
std::vector<int> axes;
for (auto i = 0; i < reduce_ndim; ++i) {
axes.push_back(i);
}
Tensor* tmp_dout = const_cast<Tensor*>(dout);
Tensor reduced_dout(dx->type());
if (axes.size() != 0) {
std::vector<int64_t> reduced_dout_dims;
for (auto i = reduce_ndim; i < dout->dims().size(); ++i) {
reduced_dout_dims.push_back(dout->dims()[i]);
}
reduced_dout.Resize(framework::make_ddim(reduced_dout_dims));
reduced_dout.mutable_data<T>(ctx.GetPlace());
auto runner = NpuOpRunner("ReduceSumD", {*dout}, {reduced_dout},
{{"axes", axes}, {"keep_dims", false}});
runner.Run(stream);
tmp_dout = &reduced_dout;
}
// stage 2
axes.clear();
for (auto i = 0; i < dx->dims().size(); ++i) {
if (dx->dims()[i] == 1) {
axes.push_back(i);
}
}
if (axes.size() != 0) {
auto runner = NpuOpRunner("ReduceSumD", {*tmp_dout}, {*dx},
{{"axes", axes}, {"keep_dims", true}});
runner.Run(stream);
} else {
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.Wait();
framework::TensorCopySync(*tmp_dout, ctx.GetPlace(), dx);
}
}
if (dy) {
// For dy
// stage 1
auto reduce_ndim = dout->dims().size() - dy->dims().size();
std::vector<int> axes;
for (auto i = 0; i < reduce_ndim; ++i) {
axes.push_back(i);
}
Tensor* tmp_dout = const_cast<Tensor*>(dout);
Tensor reduced_dout(dout->type());
if (axes.size() != 0) {
std::vector<int64_t> reduced_dout_dims;
for (auto i = reduce_ndim; i < dout->dims().size(); ++i) {
reduced_dout_dims.push_back(dout->dims()[i]);
}
reduced_dout.Resize(framework::make_ddim(reduced_dout_dims));
reduced_dout.mutable_data<T>(ctx.GetPlace());
auto runner = NpuOpRunner("ReduceSumD", {*dout}, {reduced_dout},
{{"axes", axes}, {"keep_dims", false}});
runner.Run(stream);
tmp_dout = &reduced_dout;
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.Wait();
}
// stage 2
axes.clear();
for (auto i = 0; i < dy->dims().size(); ++i) {
if (dy->dims()[i] == 1) {
axes.push_back(i);
}
}
if (axes.size() != 0) {
dy->mutable_data<T>(ctx.GetPlace());
auto runner = NpuOpRunner("ReduceSumD", {*tmp_dout}, {*dy},
{{"axes", axes}, {"keep_dims", true}});
runner.Run(stream);
} else {
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.Wait();
framework::TensorCopySync(*tmp_dout, ctx.GetPlace(), dy);
}
}
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(elementwise_add, ops::ElementwiseAddNPUKernel<float>,
ops::ElementwiseAddNPUKernel<plat::float16>);
REGISTER_OP_NPU_KERNEL( REGISTER_OP_NPU_KERNEL(elementwise_add_grad,
elementwise_add, ops::ElementwiseAddGradNPUKernel<float>,
ops::ElementwiseAddNPUKernel<paddle::platform::NPUDeviceContext, float>); ops::ElementwiseAddGradNPUKernel<plat::float16>);
#endif
...@@ -74,6 +74,7 @@ void Compare(f::Scope* scope, const p::DeviceContext& ctx, ...@@ -74,6 +74,7 @@ void Compare(f::Scope* scope, const p::DeviceContext& ctx,
{{"Out", {"Out"}}}, attrs); {{"Out", {"Out"}}}, attrs);
op->Run(*scope, place); op->Run(*scope, place);
ctx.Wait();
std::vector<T> out_vec; std::vector<T> out_vec;
TensorToVector(*tensor_out, ctx, &out_vec); TensorToVector(*tensor_out, ctx, &out_vec);
...@@ -125,12 +126,13 @@ void CompareGrad(f::Scope* scope, const p::DeviceContext& ctx, ...@@ -125,12 +126,13 @@ void CompareGrad(f::Scope* scope, const p::DeviceContext& ctx,
// run // run
f::AttributeMap attrs; f::AttributeMap attrs;
auto op = f::OpRegistry::CreateOp(op_type, auto op = f::OpRegistry::CreateOp(
{{"Out@GRAD", {"DOut"}}, {"X", {"X"}}, {"Y", {"Y"}}}, op_type, {{"Out@GRAD", {"DOut"}}, {"X", {"X"}}, {"Y", {"Y"}}},
{{"X@GRAD", {"DX"}}, {"Y@GRAD", {"DY"}}}, attrs); {{"X@GRAD", {"DX"}}, {"Y@GRAD", {"DY"}}}, attrs);
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
op->Run(*scope, place); op->Run(*scope, place);
ctx.Wait();
std::vector<T> dx_vec; std::vector<T> dx_vec;
TensorToVector(*tensor_dx, ctx, &dx_vec); TensorToVector(*tensor_dx, ctx, &dx_vec);
...@@ -179,3 +181,9 @@ TEST(elementwise_sub_grad, NPU) { ...@@ -179,3 +181,9 @@ TEST(elementwise_sub_grad, NPU) {
p::NPUDeviceContext ctx(p::NPUPlace(0)); p::NPUDeviceContext ctx(p::NPUPlace(0));
CompareGrad<float>(&scope, ctx, "elementwise_sub_grad"); CompareGrad<float>(&scope, ctx, "elementwise_sub_grad");
} }
TEST(elementwise_add_grad, NPU) {
f::Scope scope;
p::NPUDeviceContext ctx(p::NPUPlace(0));
CompareGrad<float>(&scope, ctx, "elementwise_add_grad");
}
...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef PADDLE_WITH_ASCEND_CL
#include <memory> #include <memory>
#include <string> #include <string>
...@@ -24,7 +23,7 @@ namespace operators { ...@@ -24,7 +23,7 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename DeviceContext, typename T> template <typename T>
class ElementwiseSubNPUKernel : public framework::OpKernel<T> { class ElementwiseSubNPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -43,7 +42,7 @@ class ElementwiseSubNPUKernel : public framework::OpKernel<T> { ...@@ -43,7 +42,7 @@ class ElementwiseSubNPUKernel : public framework::OpKernel<T> {
} }
}; };
template <typename DeviceContext, typename T> template <typename T>
class ElementwiseSubGradNPUKernel : public framework::OpKernel<T> { class ElementwiseSubGradNPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -51,8 +50,9 @@ class ElementwiseSubGradNPUKernel : public framework::OpKernel<T> { ...@@ -51,8 +50,9 @@ class ElementwiseSubGradNPUKernel : public framework::OpKernel<T> {
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X")); auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y")); auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
dx->mutable_data<T>(ctx.GetPlace()); auto stream =
dy->mutable_data<T>(ctx.GetPlace()); ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
// NOTE(zhiqiu): It seems Ascend Sub follow the broadcast sematics with // NOTE(zhiqiu): It seems Ascend Sub follow the broadcast sematics with
// default axis=-1? // default axis=-1?
...@@ -66,9 +66,8 @@ class ElementwiseSubGradNPUKernel : public framework::OpKernel<T> { ...@@ -66,9 +66,8 @@ class ElementwiseSubGradNPUKernel : public framework::OpKernel<T> {
// 1. [2, 3, 5] => [3, 5], ReduceSumD on axis = 0, keep_dims = false. // 1. [2, 3, 5] => [3, 5], ReduceSumD on axis = 0, keep_dims = false.
// 2. [3, 5] => [1, 5], ReduceSumD on axis = 0, keep_dims = true. // 2. [3, 5] => [1, 5], ReduceSumD on axis = 0, keep_dims = true.
auto stream = if (dx) {
ctx.template device_context<paddle::platform::NPUDeviceContext>() dx->mutable_data<T>(ctx.GetPlace());
.stream();
// For dx // For dx
// stage 1 // stage 1
auto reduce_ndim = dout->dims().size() - dx->dims().size(); auto reduce_ndim = dout->dims().size() - dx->dims().size();
...@@ -105,16 +104,19 @@ class ElementwiseSubGradNPUKernel : public framework::OpKernel<T> { ...@@ -105,16 +104,19 @@ class ElementwiseSubGradNPUKernel : public framework::OpKernel<T> {
} else { } else {
framework::TensorCopySync(*tmp_dout, ctx.GetPlace(), dx); framework::TensorCopySync(*tmp_dout, ctx.GetPlace(), dx);
} }
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
// For dy // For dy
// stage 1 // stage 1
reduce_ndim = dout->dims().size() - dy->dims().size(); auto reduce_ndim = dout->dims().size() - dy->dims().size();
axes.clear(); std::vector<int> axes;
for (auto i = 0; i < reduce_ndim; ++i) { for (auto i = 0; i < reduce_ndim; ++i) {
axes.push_back(i); axes.push_back(i);
} }
tmp_dout = const_cast<Tensor*>(dout); Tensor* tmp_dout = const_cast<Tensor*>(dout);
Tensor reduced_dy(dy->type()); Tensor reduced_dy(dy->type());
Tensor reduced_dout(dy->type());
if (axes.size() != 0) { if (axes.size() != 0) {
std::vector<int64_t> reduced_dout_dims; std::vector<int64_t> reduced_dout_dims;
...@@ -150,22 +152,18 @@ class ElementwiseSubGradNPUKernel : public framework::OpKernel<T> { ...@@ -150,22 +152,18 @@ class ElementwiseSubGradNPUKernel : public framework::OpKernel<T> {
auto runner = NpuOpRunner("Neg", {*tmp_dy}, {*dy}, {}); auto runner = NpuOpRunner("Neg", {*tmp_dy}, {*dy}, {});
runner.Run(stream); runner.Run(stream);
} }
}
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(elementwise_sub, ops::ElementwiseSubNPUKernel<float>,
ops::ElementwiseSubNPUKernel<plat::float16>);
REGISTER_OP_NPU_KERNEL( REGISTER_OP_NPU_KERNEL(elementwise_sub_grad,
elementwise_sub, ops::ElementwiseSubGradNPUKernel<float>,
ops::ElementwiseSubNPUKernel<paddle::platform::NPUDeviceContext, float>, ops::ElementwiseSubGradNPUKernel<plat::float16>);
ops::ElementwiseSubNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_NPU_KERNEL(
elementwise_sub_grad,
ops::ElementwiseSubGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::ElementwiseSubGradNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册