Unverified commit 342252c9, authored by Leo Chen, committed by GitHub

[NPU] change transpose to transpose2 (#31734)

* change transpose to transpose2

* fix bug
Parent 7b450e78
@@ -9,14 +9,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#ifdef PADDLE_WITH_ASCEND_CL
+#include <iostream>
 #include <memory>
 #include <string>
-#include <iostream>
 
-#include "paddle/fluid/operators/npu_op_runner.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/expand_op.h"
+#include "paddle/fluid/operators/npu_op_runner.h"
 
 namespace paddle {
 namespace operators {
@@ -31,53 +30,52 @@ class TransposeNPUKernel : public framework::OpKernel<T> {
     framework::NPUAttributeMap attr_input = {{"perm", axis}};
     out->mutable_data<T>(ctx.device_context().GetPlace());
     auto runner = NpuOpRunner("TransposeD", {*x}, {*out}, attr_input);
-    auto stream = ctx.template device_context<paddle::platform::NPUDeviceContext>().stream();
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
     runner.Run(stream);
   }
 };
 
 template <typename T>
 class TransposeGradNPUKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    auto* out_grad = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
-    auto* x_grad = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* out_grad =
+        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+    auto* x_grad =
+        ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
     std::vector<int> axis = ctx.Attr<std::vector<int>>("axis");
     std::vector<int> reversed_axis(axis);
     for (size_t i = 0; i < axis.size(); i++) {
       reversed_axis[axis[i]] = i;
     }
+    x_grad->mutable_data<T>(ctx.GetPlace());
     framework::NPUAttributeMap attr_input = {{"perm", reversed_axis}};
     auto runner = NpuOpRunner("TransposeD", {*out_grad}, {*x_grad}, attr_input);
-    auto stream = ctx.template device_context<paddle::platform::NPUDeviceContext>().stream();
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
     runner.Run(stream);
   }
 };
 
-}
-}
+}  // namespace operators
+}  // namespace paddle
 
 namespace ops = paddle::operators;
 
-REGISTER_OP_NPU_KERNEL(transpose,
+REGISTER_OP_NPU_KERNEL(
+    transpose2,
     ops::TransposeNPUKernel<paddle::platform::NPUDeviceContext, float>,
-    ops::TransposeNPUKernel<paddle::platform::NPUDeviceContext, paddle::platform::float16>,
+    ops::TransposeNPUKernel<paddle::platform::NPUDeviceContext,
+                            paddle::platform::float16>,
     ops::TransposeNPUKernel<paddle::platform::NPUDeviceContext, int>,
     ops::TransposeNPUKernel<paddle::platform::NPUDeviceContext, uint8_t>,
-    ops::TransposeNPUKernel<paddle::platform::NPUDeviceContext, int8_t>
-);
+    ops::TransposeNPUKernel<paddle::platform::NPUDeviceContext, int8_t>);
 
-REGISTER_OP_NPU_KERNEL(transpose_grad,
-    ops::TransposeGradNPUKernel<float>,
+REGISTER_OP_NPU_KERNEL(transpose2_grad, ops::TransposeGradNPUKernel<float>,
     ops::TransposeGradNPUKernel<paddle::platform::float16>,
     ops::TransposeGradNPUKernel<int>,
     ops::TransposeGradNPUKernel<uint8_t>,
-    ops::TransposeGradNPUKernel<int8_t>
-);
-#endif
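For reference, the backward kernel above inverts the forward permutation with reversed_axis[axis[i]] = i and then runs the same TransposeD operator on the output gradient. The following is a minimal standalone sketch of that inversion step, assuming plain C++ with no Paddle dependencies (the function name and example values are illustrative only, not part of the patch):

#include <cassert>
#include <cstddef>
#include <vector>

// Given a forward permutation axis (output dim i reads input dim axis[i]),
// build the permutation that maps the transposed layout back to the original.
std::vector<int> InversePermutation(const std::vector<int>& axis) {
  std::vector<int> reversed_axis(axis.size());
  for (size_t i = 0; i < axis.size(); i++) {
    reversed_axis[axis[i]] = static_cast<int>(i);
  }
  return reversed_axis;
}

int main() {
  std::vector<int> axis = {2, 0, 1};                 // hypothetical forward perm
  std::vector<int> inv = InversePermutation(axis);   // yields {1, 2, 0}
  // Applying axis and then inv restores the original dimension order.
  for (size_t i = 0; i < axis.size(); i++) {
    assert(inv[axis[i]] == static_cast<int>(i));
  }
  return 0;
}

Running TransposeD with this inverse permutation on the output gradient is how the kernel recovers the gradient in the input layout.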
@@ -13,12 +13,12 @@ limitations under the License. */
 #include <unistd.h>
 #endif
 
-#include <string>
 #include <cmath>
+#include <iostream>
+#include <numeric>
+#include <string>
 #include <thread>  // NOLINT
 #include <vector>
-#include <numeric>
-#include <iostream>
 
 #include "gtest/gtest.h"
 #include "paddle/fluid/framework/op_registry.h"
@@ -32,17 +32,18 @@ namespace f = paddle::framework;
 namespace p = paddle::platform;
 namespace m = paddle::operators::math;
 
-USE_OP(transpose);
-USE_OP_DEVICE_KERNEL(transpose, NPU);
+USE_OP(transpose2);
+USE_OP_DEVICE_KERNEL(transpose2, NPU);
 
 template <typename T>
 void Compare(f::Scope* scope, const p::DeviceContext& ctx) {
   // init
   auto x = scope->Var("X");
   auto out = scope->Var("Out");
+  auto xshape = scope->Var("XShape");
   auto* x_t = x->GetMutable<f::LoDTensor>();
   auto* out_t = out->GetMutable<f::LoDTensor>();
+  auto* xshape_t = xshape->GetMutable<f::LoDTensor>();
   auto place = ctx.GetPlace();
 
   int dim0 = 2;
@@ -54,12 +55,13 @@ void Compare(f::Scope* scope, const p::DeviceContext& ctx) {
   ctx.Wait();
 
   out_t->mutable_data<T>(place);
   ctx.Wait();
+  xshape_t->Resize({dim0, dim1});
+  xshape_t->mutable_data<T>(place);
 
-  f::AttributeMap attrs = {
-      {"axis", std::vector<int>({1, 0})},
-      {"data_format", std::string("AnyLayout")}
-  };
+  f::AttributeMap attrs = {{"axis", std::vector<int>({1, 0})},
+                           {"data_format", std::string("AnyLayout")}};
 
-  auto op = f::OpRegistry::CreateOp("transpose", {{"X", {"X"}}},
-                                    {{"Out", {"Out"}}}, attrs);
+  auto op = f::OpRegistry::CreateOp("transpose2", {{"X", {"X"}}},
+                                    {{"Out", {"Out"}}, {"XShape", {"XShape"}}},
+                                    attrs);
 
   ctx.Wait();
   op->Run(*scope, place);
   ctx.Wait();
@@ -76,42 +78,37 @@ void Compare(f::Scope* scope, const p::DeviceContext& ctx) {
   EXPECT_EQ(out_v[5], 5);
 }
 
 template <typename T>
 void CompareGrad(f::Scope* scope, const p::DeviceContext& ctx) {
   // init
-  auto x = scope->Var("X");
+  auto xshape = scope->Var("XShape");
   auto x_grad = scope->Var("X@GRAD");
-  auto out = scope->Var("Out");
   auto out_grad = scope->Var("Out@GRAD");
 
   auto* x_grad_t = x_grad->GetMutable<f::LoDTensor>();
-  auto* x_t = x->GetMutable<f::LoDTensor>();
+  auto* xshape_t = xshape->GetMutable<f::LoDTensor>();
   auto* out_grad_t = out_grad->GetMutable<f::LoDTensor>();
-  auto* out_t = out->GetMutable<f::LoDTensor>();
 
   int dim0 = 2;
   int dim1 = 3;
   auto place = ctx.GetPlace();
 
   TensorFromVector(std::vector<T>({0, 1, 2, 3, 4, 5}), ctx, out_grad_t);
-  TensorFromVector(std::vector<T>({0, 1, 2, 3, 4, 5}), ctx, x_t);
   ctx.Wait();
 
   x_grad_t->Resize({dim0, dim1});
-  x_t->Resize({dim0, dim1});
+  xshape_t->Resize(
+      {0, dim0,
+       dim1});  // NOTE(zhiqiu): 0 is needed, see its infershape function
   out_grad_t->Resize({dim0, dim1});
-  out_t->Resize({dim0, dim1});
 
-  x_grad_t->mutable_data<T>(place);
-  out_t->mutable_data<T>(place);
-  ctx.Wait();
-  f::AttributeMap attrs = {
-      {"axis", std::vector<int>({1, 0})},
-      {"data_format", std::string("AnyLayout")}
-  };
+  f::AttributeMap attrs = {{"axis", std::vector<int>({1, 0})},
+                           {"data_format", std::string("AnyLayout")}};
 
   auto op = f::OpRegistry::CreateOp(
-      "transpose_grad",
-      {{"Out@GRAD", {"Out@GRAD"}}, {"X", {"X"}}, {"Out", {"Out"}}},
+      "transpose2_grad", {{"Out@GRAD", {"Out@GRAD"}}, {"XShape", {"XShape"}}},
      {{"X@GRAD", {"X@GRAD"}}}, attrs);
 
   op->Run(*scope, place);
   ctx.Wait();
   std::vector<T> out_v;
@@ -125,19 +122,16 @@ void CompareGrad(f::Scope* scope, const p::DeviceContext& ctx) {
   EXPECT_EQ(out_v[3], 4);
   EXPECT_EQ(out_v[4], 2);
   EXPECT_EQ(out_v[5], 5);
 }
 
-TEST(transpose, NPU_fp32) {
+TEST(transpose2, NPU_fp32) {
   f::Scope scope;
   p::NPUDeviceContext ctx(p::NPUPlace(0));
   Compare<float>(&scope, ctx);
 }
 
-TEST(transpose_grad, NPU_fp32) {
+TEST(transpose2_grad, NPU_fp32) {
   f::Scope scope;
   p::NPUDeviceContext ctx(p::NPUPlace(0));
   CompareGrad<float>(&scope, ctx);
 }
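The expected values checked in Compare and CompareGrad follow from transposing a row-major 2x3 buffer holding 0..5 with axis = {1, 0}: the result is {0, 3, 1, 4, 2, 5}, and transposing the 3x2 result again restores the input. Below is a small host-only sketch that re-derives those numbers, independent of the NPU kernels (the helper name is illustrative, not part of the test):

#include <cassert>
#include <vector>

// Transpose a row-major dim0 x dim1 buffer with permutation {1, 0}.
std::vector<int> Transpose2D(const std::vector<int>& in, int dim0, int dim1) {
  std::vector<int> out(in.size());
  for (int i = 0; i < dim0; ++i) {
    for (int j = 0; j < dim1; ++j) {
      out[j * dim0 + i] = in[i * dim1 + j];
    }
  }
  return out;
}

int main() {
  std::vector<int> x = {0, 1, 2, 3, 4, 5};      // 2x3, same data as the test
  std::vector<int> out = Transpose2D(x, 2, 3);  // {0, 3, 1, 4, 2, 5}
  assert(out[3] == 4 && out[4] == 2 && out[5] == 5);
  // Transposing the 3x2 result restores the original buffer, which is the
  // round-trip property the grad kernel relies on via the inverse permutation.
  assert(Transpose2D(out, 3, 2) == x);
  return 0;
}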
@@ -30,7 +30,7 @@ paddle.enable_static()
 class TestTransposeOp(OpTest):
     def setUp(self):
         self.set_npu()
-        self.op_type = "transpose"
+        self.op_type = "transpose2"
         self.place = paddle.NPUPlace(0)
         self.init_dtype()
         self.init_input_output()
......