diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc
index 0de97a62ac0e1e574ccbdfaf4c993366f1a0d77f..7344bcfb6b8b27384ccc4198c55868efa4455a14 100644
--- a/paddle/fluid/framework/tensor_util.cc
+++ b/paddle/fluid/framework/tensor_util.cc
@@ -97,6 +97,36 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
         "Copy from %s to %s is not supported.", src_place, dst_place));
   }
 #endif
+#ifdef PADDLE_WITH_ASCEND_CL
+  // TODO(zhiqiu): handle different condition like CUDA code below
+  else if (platform::is_npu_place(src_place) &&  // NOLINT
+           platform::is_cpu_place(dst_place)) {
+    auto stream = reinterpret_cast<const platform::NPUDeviceContext&>(ctx).stream();
+    memory::Copy(BOOST_GET_CONST(platform::CPUPlace, dst_place), dst_ptr,
+                 BOOST_GET_CONST(platform::NPUPlace, src_place), src_ptr, size, stream);
+  }
+  else if (platform::is_cpu_place(src_place) &&  // NOLINT
+           platform::is_npu_place(dst_place)) {
+    auto stream = reinterpret_cast<const platform::NPUDeviceContext&>(ctx).stream();
+    memory::Copy(BOOST_GET_CONST(platform::NPUPlace, dst_place), dst_ptr,
+                 BOOST_GET_CONST(platform::CPUPlace, src_place), src_ptr, size, stream);
+  }
+  else if (platform::is_npu_place(src_place) &&  // NOLINT
+           platform::is_npu_place(dst_place)) {
+    if (src_ptr == dst_ptr) {
+      VLOG(3) << "Skip copy the same data async from " << src_place << " to "
+              << dst_place;
+      return;
+    }
+    auto stream = reinterpret_cast<const platform::NPUDeviceContext&>(ctx).stream();
+    memory::Copy(BOOST_GET_CONST(platform::NPUPlace, dst_place), dst_ptr,
+                 BOOST_GET_CONST(platform::NPUPlace, src_place), src_ptr, size, stream);
+  }
+  else {  // NOLINT
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "Copy from %s to %s is not supported.", src_place, dst_place));
+  }
+#endif
 #ifdef PADDLE_WITH_CUDA
   else if (platform::is_cuda_pinned_place(src_place) &&  // NOLINT
            platform::is_cuda_pinned_place(dst_place)) {
@@ -304,6 +334,32 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
         "Copy from %s to %s is not supported.", src_place, dst_place));
   }
 #endif
+#ifdef PADDLE_WITH_ASCEND_CL
+  else if (platform::is_npu_place(src_place) &&  // NOLINT
+           platform::is_cpu_place(dst_place)) {
+    memory::Copy(BOOST_GET_CONST(platform::CPUPlace, dst_place), dst_ptr,
+                 BOOST_GET_CONST(platform::NPUPlace, src_place), src_ptr, size, nullptr);
+  }
+  else if (platform::is_cpu_place(src_place) &&  // NOLINT
+           platform::is_npu_place(dst_place)) {
+    memory::Copy(BOOST_GET_CONST(platform::NPUPlace, dst_place), dst_ptr,
+                 BOOST_GET_CONST(platform::CPUPlace, src_place), src_ptr, size, nullptr);
+  }
+  else if (platform::is_npu_place(src_place) &&  // NOLINT
+           platform::is_npu_place(dst_place)) {
+    if (src_ptr == dst_ptr) {
+      VLOG(3) << "Skip copy the same data sync from " << src_place << " to "
+              << dst_place;
+      return;
+    }
+    memory::Copy(BOOST_GET_CONST(platform::NPUPlace, dst_place), dst_ptr,
+                 BOOST_GET_CONST(platform::NPUPlace, src_place), src_ptr, size, nullptr);
+  }
+  else {  // NOLINT
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "Copy from %s to %s is not supported.", src_place, dst_place));
+  }
+#endif
 #ifdef PADDLE_WITH_CUDA
   else if (platform::is_cuda_pinned_place(src_place) &&  // NOLINT
            platform::is_cuda_pinned_place(dst_place)) {
@@ -433,10 +489,9 @@ class AnyVisitor : public boost::static_visitor<bool> {
 
   bool GetResult(const framework::Tensor& out,
                  const platform::NPUPlace& npu) const {
-    PADDLE_THROW(platform::errors::Unimplemented(
-        "Not supported on place (%s) ",
-        npu));
-    //return GetResultHelper(out, npu);
+    PADDLE_THROW(
+        platform::errors::Unimplemented("Not supported on place (%s) ", npu));
+    // return GetResultHelper(out, npu);
   }
 
   bool GetResult(const framework::Tensor& out,
@@ -642,7 +697,7 @@ struct BothFalseVisitor : public boost::static_visitor<> {
   }
 
   void VisitorImpl(const platform::NPUPlace& npu) const {
-    //TODO(zhiqiu)
+    // TODO(zhiqiu)
   }
 
   void VisitorImpl(const platform::CPUPlace& cpu) const {
diff --git a/paddle/fluid/operators/elementwise/CMakeLists.txt b/paddle/fluid/operators/elementwise/CMakeLists.txt
index d3f7290aada882bca8d0dde6106aaa065c809593..1309f1d457ad4d1ff13c9d91d9f6a94170f80626 100644
--- a/paddle/fluid/operators/elementwise/CMakeLists.txt
+++ b/paddle/fluid/operators/elementwise/CMakeLists.txt
@@ -8,4 +8,4 @@ register_operators(DEPS op_version_registry)
 cc_test(test_elementwise_add_op_inplace SRCS test_elementwise_add_op_inplace.cc DEPS op_registry elementwise_add_op scope device_context enforce executor)
 cc_test(test_elementwise_div_grad_grad SRCS test_elementwise_div_grad_grad.cc DEPS op_registry elementwise_div_op scope device_context enforce executor)
 cc_test(test_elementwise_add_grad_grad SRCS test_elementwise_add_grad_grad.cc DEPS op_registry elementwise_add_op scope device_context enforce executor)
-cc_test(elementwise_add_op_npu_test SRCS elementwise_add_op_npu_test.cc DEPS op_registry elementwise_add_op scope device_context enforce executor)
+cc_test(elementwise_op_npu_test SRCS elementwise_op_npu_test.cc DEPS op_registry elementwise_add_op elementwise_sub_op scope device_context enforce executor)
diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op_npu.cc b/paddle/fluid/operators/elementwise/elementwise_add_op_npu.cc
index 6e48b84c9c6470a5857eb5d75e9d690f0736dd78..1e7e5e02c0181f8828a59b9403ac24f40347f8b6 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op_npu.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op_npu.cc
@@ -29,17 +29,9 @@ class ElementwiseAddNPUKernel : public framework::OpKernel<T> {
     auto* x = ctx.Input<Tensor>("X");
     auto* y = ctx.Input<Tensor>("Y");
     auto* out = ctx.Output<Tensor>("Out");
-
     out->mutable_data<T>(ctx.GetPlace());
 
-    // TODO(zhiqiu): get the attr infomation of Ascend op and
-    // convert paddle AttributeMap to Ascend attrs.
-    // Ascend op add has no attribute ?
-    // int axis = ctx.Attr<int>("axis");
-
-    // NOTE(zhiqiu): the order of inputs and outputs is important
     auto runner = NpuOpRunner("Add", {*x, *y}, {*out}, {});
-
     auto stream =
         ctx.template device_context<paddle::platform::NPUDeviceContext>()
             .stream();
diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op_npu_test.cc b/paddle/fluid/operators/elementwise/elementwise_add_op_npu_test.cc
deleted file mode 100644
index adc31cae0ee76de9569b4b277f35a56201e96104..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/elementwise/elementwise_add_op_npu_test.cc
+++ /dev/null
@@ -1,87 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#ifndef _WIN32
-#include <unistd.h>
-#endif
-
-#include <string>
-#include <thread>  // NOLINT
-#include <vector>
-
-#include "gtest/gtest.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/framework/program_desc.h"
-#include "paddle/fluid/operators/dropout_op.h"
-#include "paddle/fluid/operators/math/math_function.h"
-#include "paddle/fluid/string/printf.h"
-
-namespace f = paddle::framework;
-namespace p = paddle::platform;
-namespace m = paddle::operators::math;
-
-USE_OP(elementwise_add);
-USE_OP_DEVICE_KERNEL(elementwise_add, NPU);
-
-void Compare(f::Scope* scope, const p::DeviceContext& ctx) {
-  // init
-  auto x = scope->Var("X");
-  auto tensor_x = x->GetMutable<f::LoDTensor>();
-
-  auto y = scope->Var("Y");
-  auto tensor_y = y->GetMutable<f::LoDTensor>();
-
-  std::vector<float> init;
-  for (int64_t i = 0; i < 10 * 10; ++i) {
-    init.push_back(1.0);
-  }
-
-  TensorFromVector(init, ctx, tensor_x);
-  tensor_x->Resize({10, 10});
-  TensorFromVector(init, ctx, tensor_y);
-  tensor_y->Resize({10, 10});
-
-  ctx.Wait();
-
-  auto place = ctx.GetPlace();
-  auto out = scope->Var("Out");
-  auto tensor_out = out->GetMutable<f::LoDTensor>();
-  tensor_out->Resize({10, 10});
-  tensor_out->mutable_data<float>(place);  // allocate
-
-  // run
-  f::AttributeMap attrs;
-  auto op =
-      f::OpRegistry::CreateOp("elementwise_add", {{"X", {"X"}}, {"Y", {"Y"}}},
-                              {{"Out", {"Out"}}}, attrs);
-
-  op->Run(*scope, place);
-
-  std::vector<float> out_vec;
-  TensorToVector(*tensor_out, ctx, &out_vec);
-
-  ctx.Wait();
-
-  EXPECT_EQ(out_vec.size(), init.size());
-  for (uint32_t i = 0; i < out_vec.size(); i++) {
-    EXPECT_EQ(out_vec[i], 2.0);
-  }
-}
-
-TEST(elementwise_add, NPU) {
-  f::Scope scope;
-  p::NPUDeviceContext ctx(p::NPUPlace(0));
-  Compare(&scope, ctx);
-}
diff --git a/paddle/fluid/operators/elementwise/elementwise_op_npu_test.cc b/paddle/fluid/operators/elementwise/elementwise_op_npu_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0cb8fd1c5781f4a154782e744ab4b0ccd0e92d9a
--- /dev/null
+++ b/paddle/fluid/operators/elementwise/elementwise_op_npu_test.cc
@@ -0,0 +1,181 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifndef _WIN32
+#include <unistd.h>
+#endif
+
+#include <string>
+#include <thread>  // NOLINT
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/operators/dropout_op.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/string/printf.h"
+
+namespace f = paddle::framework;
+namespace p = paddle::platform;
+namespace m = paddle::operators::math;
+
+USE_OP(elementwise_add);
+USE_OP_DEVICE_KERNEL(elementwise_add, NPU);
+USE_OP(elementwise_sub);
+USE_OP_DEVICE_KERNEL(elementwise_sub, NPU);
+
+template <typename T>
+void Compare(f::Scope* scope, const p::DeviceContext& ctx,
+             std::string op_type) {
+  // init
+  auto x = scope->Var("X");
+  auto tensor_x = x->GetMutable<f::LoDTensor>();
+
+  auto y = scope->Var("Y");
+  auto tensor_y = y->GetMutable<f::LoDTensor>();
+
+  std::vector<T> init_x;
+  for (int64_t i = 0; i < 10 * 10; ++i) {
+    init_x.push_back(static_cast<T>(1.0));
+  }
+
+  std::vector<T> init_y;
+  for (int64_t i = 0; i < 10 * 10; ++i) {
+    init_y.push_back(static_cast<T>(2.0));
+  }
+
+  TensorFromVector(init_x, ctx, tensor_x);
+  tensor_x->Resize({10, 10});
+  TensorFromVector(init_y, ctx, tensor_y);
+  tensor_y->Resize({10, 10});
+
+  ctx.Wait();
+
+  auto place = ctx.GetPlace();
+  auto out = scope->Var("Out");
+  auto tensor_out = out->GetMutable<f::LoDTensor>();
+
+  // run
+  f::AttributeMap attrs;
+  auto op = f::OpRegistry::CreateOp(op_type, {{"X", {"X"}}, {"Y", {"Y"}}},
+                                    {{"Out", {"Out"}}}, attrs);
+
+  op->Run(*scope, place);
+
+  std::vector<T> out_vec;
+  TensorToVector(*tensor_out, ctx, &out_vec);
+
+  ctx.Wait();
+  float expected;
+  if (op_type == "elementwise_add") {
+    expected = 3.0;
+  } else if (op_type == "elementwise_sub") {
+    expected = -1.0;
+  }
+  EXPECT_EQ(out_vec.size(), init_x.size());
+  for (uint32_t i = 0; i < out_vec.size(); i++) {
+    EXPECT_EQ(out_vec[i], static_cast<T>(expected));
+  }
+}
+
+template <typename T>
+void CompareGrad(f::Scope* scope, const p::DeviceContext& ctx,
+                 std::string op_type) {
+  // init
+  auto dout = scope->Var("DOut");
+  auto tensor_dout = dout->GetMutable<f::LoDTensor>();
+  tensor_dout->Resize({2, 3, 5});
+
+  auto x = scope->Var("X");
+  auto tensor_x = x->GetMutable<f::LoDTensor>();
+  tensor_x->Resize({2, 3, 5});
+
+  auto y = scope->Var("Y");
+  auto tensor_y = y->GetMutable<f::LoDTensor>();
+  tensor_y->Resize({1, 5});
+
+  auto dx = scope->Var("DX");
+  auto tensor_dx = dx->GetMutable<f::LoDTensor>();
+
+  auto dy = scope->Var("DY");
+  auto tensor_dy = dy->GetMutable<f::LoDTensor>();
+
+  std::vector<T> init_dout;
+  for (int64_t i = 0; i < tensor_dout->numel(); ++i) {
+    init_dout.push_back(static_cast<T>(1.0));
+  }
+
+  TensorFromVector(init_dout, ctx, tensor_dout);
+  tensor_dout->Resize({2, 3, 5});
+
+  ctx.Wait();
+
+  // run
+  f::AttributeMap attrs;
+  auto op = f::OpRegistry::CreateOp(op_type,
+                                    {{"Out@GRAD", {"DOut"}}, {"X", {"X"}}, {"Y", {"Y"}}},
+                                    {{"X@GRAD", {"DX"}}, {"Y@GRAD", {"DY"}}}, attrs);
+
+  auto place = ctx.GetPlace();
+  op->Run(*scope, place);
+
+  std::vector<T> dx_vec;
+  TensorToVector(*tensor_dx, ctx, &dx_vec);
+
+  std::vector<T> dy_vec;
+  TensorToVector(*tensor_dy, ctx, &dy_vec);
+
+  ctx.Wait();
+  float expected_x, expected_y;
+  if (op_type == "elementwise_add_grad") {
+    expected_x = 1.0;
+    expected_y = 6.0;
+  } else if (op_type == "elementwise_sub_grad") {
+    expected_x = 1.0;
+    expected_y = -6.0;
+  }
+
+  for (uint32_t i = 0; i < dx_vec.size(); i++) {
+    EXPECT_EQ(dx_vec[i], static_cast<T>(expected_x));
+  }
+  for (uint32_t i = 0; i < dy_vec.size(); i++) {
+    EXPECT_EQ(dy_vec[i], static_cast<T>(expected_y));
+  }
+}
+
+TEST(elementwise_add, NPU_fp32) {
+  f::Scope scope;
+  p::NPUDeviceContext ctx(p::NPUPlace(0));
+  Compare<float>(&scope, ctx, "elementwise_add");
+}
+
+TEST(elementwise_sub, NPU_fp32) {
+  f::Scope scope;
+  p::NPUDeviceContext ctx(p::NPUPlace(0));
+  Compare<float>(&scope, ctx, "elementwise_sub");
+}
+
+TEST(elementwise_sub, NPU_fp16) {
+  f::Scope scope;
+  p::NPUDeviceContext ctx(p::NPUPlace(0));
+  Compare<p::float16>(&scope, ctx, "elementwise_sub");
+}
+
+TEST(elementwise_sub_grad, NPU) {
+  f::Scope scope;
+  p::NPUDeviceContext ctx(p::NPUPlace(0));
+  CompareGrad<float>(&scope, ctx, "elementwise_sub_grad");
+}
diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op_npu.cc b/paddle/fluid/operators/elementwise/elementwise_sub_op_npu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c3cf76451f62fe23dc88b77f4385f928a4910dbb
--- /dev/null
+++ b/paddle/fluid/operators/elementwise/elementwise_sub_op_npu.cc
@@ -0,0 +1,171 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef PADDLE_WITH_ASCEND_CL
+#include <memory>
+#include <string>
+
+#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
+#include "paddle/fluid/operators/npu_op_runner.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename DeviceContext, typename T>
+class ElementwiseSubNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<Tensor>("X");
+    auto* y = ctx.Input<Tensor>("Y");
+    auto* out = ctx.Output<Tensor>("Out");
+
+    out->mutable_data<T>(ctx.GetPlace());
+
+    auto runner = NpuOpRunner("Sub", {*x, *y}, {*out}, {});
+
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+    runner.Run(stream);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class ElementwiseSubGradNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+
+    dx->mutable_data<T>(ctx.GetPlace());
+    dy->mutable_data<T>(ctx.GetPlace());
+
+    // NOTE(zhiqiu): It seems Ascend Sub follow the broadcast sematics with
+    // default axis=-1?
+    // So, the sub_grad should do reduce if needed.
+    // For example, the shape of each variable in elementwise_sub:
+    // x, dx: [2, 3, 5]
+    // y, dy: [1, 5]
+    // out, dout: [2, 3, 5]
+    // Then, out = x - y => dx = dout, dy = -dout
+    // And, the shape of dy can be computed by two stages reduce,
+    // 1. [2, 3, 5] => [3, 5], ReduceSumD on axis = 0, keep_dims = false.
+    // 2. [3, 5] => [1, 5], ReduceSumD on axis = 0, keep_dims = true.
+
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+    // For dx
+    // stage 1
+    auto reduce_ndim = dout->dims().size() - dx->dims().size();
+    std::vector<int> axes;
+    for (auto i = 0; i < reduce_ndim; ++i) {
+      axes.push_back(i);
+    }
+    Tensor* tmp_dout = const_cast<Tensor*>(dout);
+    Tensor reduced_dout(dx->type());
+    if (axes.size() != 0) {
+      std::vector<int64_t> reduced_dout_dims;
+      for (auto i = reduce_ndim; i < dout->dims().size(); ++i) {
+        reduced_dout_dims.push_back(dout->dims()[i]);
+      }
+      reduced_dout.Resize(framework::make_ddim(reduced_dout_dims));
+      reduced_dout.mutable_data<T>(ctx.GetPlace());
+      auto runner = NpuOpRunner("ReduceSumD", {*dout}, {reduced_dout},
+                                {{"axes", axes}, {"keep_dims", false}});
+      runner.Run(stream);
+      tmp_dout = &reduced_dout;
+    }
+
+    // stage 2
+    axes.clear();
+    for (auto i = 0; i < dx->dims().size(); ++i) {
+      if (dx->dims()[i] == 1) {
+        axes.push_back(i);
+      }
+    }
+    if (axes.size() != 0) {
+      auto runner = NpuOpRunner("ReduceSumD", {*tmp_dout}, {*dx},
+                                {{"axes", axes}, {"keep_dims", true}});
+      runner.Run(stream);
+    } else {
+      framework::TensorCopySync(*tmp_dout, ctx.GetPlace(), dx);
+    }
+
+    // For dy
+    // stage 1
+    reduce_ndim = dout->dims().size() - dy->dims().size();
+    axes.clear();
+    for (auto i = 0; i < reduce_ndim; ++i) {
+      axes.push_back(i);
+    }
+    tmp_dout = const_cast<Tensor*>(dout);
+    Tensor reduced_dy(dy->type());
+
+    if (axes.size() != 0) {
+      std::vector<int64_t> reduced_dout_dims;
+      for (auto i = reduce_ndim; i < dout->dims().size(); ++i) {
+        reduced_dout_dims.push_back(dout->dims()[i]);
+      }
+      reduced_dout.Resize(framework::make_ddim(reduced_dout_dims));
+      reduced_dout.mutable_data<T>(ctx.GetPlace());
+      auto runner = NpuOpRunner("ReduceSumD", {*dout}, {reduced_dout},
+                                {{"axes", axes}, {"keep_dims", false}});
+      runner.Run(stream);
+      tmp_dout = &reduced_dout;
+    }
+
+    // stage 2
+    axes.clear();
+    Tensor* tmp_dy = tmp_dout;
+    for (auto i = 0; i < dy->dims().size(); ++i) {
+      if (dy->dims()[i] == 1) {
+        axes.push_back(i);
+      }
+    }
+    if (axes.size() != 0) {
+      reduced_dy.Resize(dy->dims());
+      reduced_dy.mutable_data<T>(ctx.GetPlace());
+      auto runner = NpuOpRunner("ReduceSumD", {*tmp_dout}, {reduced_dy},
+                                {{"axes", axes}, {"keep_dims", true}});
+      runner.Run(stream);
+      tmp_dy = &reduced_dy;
+    }
+
+    // stage 3, negative
+    auto runner = NpuOpRunner("Neg", {*tmp_dy}, {*dy}, {});
+    runner.Run(stream);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_NPU_KERNEL(
+    elementwise_sub,
+    ops::ElementwiseSubNPUKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::ElementwiseSubNPUKernel<paddle::platform::NPUDeviceContext,
+                                 paddle::platform::float16>);
+
+REGISTER_OP_NPU_KERNEL(
+    elementwise_sub_grad,
+    ops::ElementwiseSubGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::ElementwiseSubGradNPUKernel<paddle::platform::NPUDeviceContext,
+                                     paddle::platform::float16>);
+#endif
diff --git a/paddle/fluid/operators/npu_op_runner.cc b/paddle/fluid/operators/npu_op_runner.cc
index 5a9f8008e7be8d4c24a79f6679bfb7db632ffb61..7af6de5224145b991b9f4f17eebbf4c3748fac59 100644
--- a/paddle/fluid/operators/npu_op_runner.cc
+++ b/paddle/fluid/operators/npu_op_runner.cc
@@ -253,7 +253,7 @@ void NpuOpRunner::Run(aclrtStream stream) {
       input_buffers_.data(), output_descs_.size(), output_descs_.data(),
       output_buffers_.data(), attr_, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL,
       stream);
-  VLOG(4) << "after aclopCompileAndExecute";
+  VLOG(4) << "after aclopCompileAndExecute: " << ret;
   PADDLE_ENFORCE_NPU_SUCCESS(ret);
 }
 }  // namespace operators