未验证 提交 846c7e70 编写于 作者: N Nyakku Shigure 提交者: GitHub

[CodeStyle] remove crlf for cpp files (#46156)

上级 c6c9c186
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <glog/logging.h> #include <glog/logging.h>
#include <iosfwd> #include <iosfwd>
#include <string> #include <string>
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/inference/analysis/argument.h" #include "paddle/fluid/inference/analysis/argument.h"
#include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/analysis/helper.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
/* /*
* AnalysisPass is a pass used to control the IR passes. * AnalysisPass is a pass used to control the IR passes.
*/ */
class AnalysisPass { class AnalysisPass {
public: public:
AnalysisPass() = default; AnalysisPass() = default;
virtual ~AnalysisPass() = default; virtual ~AnalysisPass() = default;
// Run on a single Graph. // Run on a single Graph.
void Run(Argument* argument) { RunImpl(argument); } void Run(Argument* argument) { RunImpl(argument); }
// Human-readable short representation. // Human-readable short representation.
virtual std::string repr() const = 0; virtual std::string repr() const = 0;
// Human-readable long description. // Human-readable long description.
virtual std::string description() const { return "No DOC"; } virtual std::string description() const { return "No DOC"; }
protected: protected:
// User should implement these. // User should implement these.
virtual void RunImpl(Argument* argument) = 0; virtual void RunImpl(Argument* argument) = 0;
}; };
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifndef _WIN32 #ifndef _WIN32
#include <unistd.h> #include <unistd.h>
#endif #endif
#include <stdio.h> #include <stdio.h>
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <vector> #include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/collective/c_allgather_op.h" #include "paddle/fluid/operators/collective/c_allgather_op.h"
#include "paddle/fluid/operators/collective/c_allreduce_op.h" #include "paddle/fluid/operators/collective/c_allreduce_op.h"
#include "paddle/fluid/operators/collective/c_broadcast_op.h" #include "paddle/fluid/operators/collective/c_broadcast_op.h"
#include "paddle/fluid/operators/collective/c_reducescatter_op.h" #include "paddle/fluid/operators/collective/c_reducescatter_op.h"
#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h" #include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#if defined(PADDLE_WITH_ASCEND_CL) #if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/npu/hccl_helper.h" #include "paddle/fluid/platform/device/npu/hccl_helper.h"
#endif #endif
namespace f = paddle::framework; namespace f = paddle::framework;
namespace p = paddle::platform; namespace p = paddle::platform;
USE_OP(c_allgather); USE_OP(c_allgather);
USE_NO_KERNEL_OP(c_gen_hccl_id); USE_NO_KERNEL_OP(c_gen_hccl_id);
USE_NO_KERNEL_OP(c_comm_init_hccl); USE_NO_KERNEL_OP(c_comm_init_hccl);
USE_OP_DEVICE_KERNEL(c_allgather, NPU); USE_OP_DEVICE_KERNEL(c_allgather, NPU);
DECLARE_string(selected_npus); DECLARE_string(selected_npus);
template <typename T> template <typename T>
void PrintDebugInfo(const std::string preStr, const std::vector<T>& data) { void PrintDebugInfo(const std::string preStr, const std::vector<T>& data) {
std::string debugstring = ""; std::string debugstring = "";
for (auto ele : data) { for (auto ele : data) {
debugstring += std::to_string(ele) + std::string(","); debugstring += std::to_string(ele) + std::string(",");
} }
VLOG(2) << preStr << ":" << std::endl << debugstring; VLOG(2) << preStr << ":" << std::endl << debugstring;
} }
void PrepareUniqueId(f::Scope* scope, void PrepareUniqueId(f::Scope* scope,
const p::DeviceContext& ctx, const p::DeviceContext& ctx,
HcclRootInfo* hccl_id) { HcclRootInfo* hccl_id) {
int rank_id = atoi(getenv("RANK_ID")); int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID")); int device_id = atoi(getenv("DEVICE_ID"));
VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id
<< "; rank_id = " << rank_id << "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID")); << "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
std::vector<int> rank_ids{0, 1}; std::vector<int> rank_ids{0, 1};
f::AttributeMap gen_hccl_id; f::AttributeMap gen_hccl_id;
std::vector<std::string> endpointList = {"127.0.0.1:6175", "127.0.0.1:6177"}; std::vector<std::string> endpointList = {"127.0.0.1:6175", "127.0.0.1:6177"};
gen_hccl_id["rank"] = rank_id; gen_hccl_id["rank"] = rank_id;
gen_hccl_id["endpoint"] = endpointList[rank_id]; gen_hccl_id["endpoint"] = endpointList[rank_id];
std::vector<std::string> other_endpoints = { std::vector<std::string> other_endpoints = {
endpointList[rank_id == 0 ? 1 : 0]}; endpointList[rank_id == 0 ? 1 : 0]};
gen_hccl_id["other_endpoints"] = other_endpoints; gen_hccl_id["other_endpoints"] = other_endpoints;
auto out = scope->Var("Out"); auto out = scope->Var("Out");
auto id = out->GetMutable<HcclRootInfo>(); auto id = out->GetMutable<HcclRootInfo>();
VLOG(3) << "break"; VLOG(3) << "break";
auto comm_init_op = f::OpRegistry::CreateOp( auto comm_init_op = f::OpRegistry::CreateOp(
"c_gen_hccl_id", {}, {{"Out", {"Out"}}}, gen_hccl_id); "c_gen_hccl_id", {}, {{"Out", {"Out"}}}, gen_hccl_id);
VLOG(3) << "break"; VLOG(3) << "break";
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place); comm_init_op->Run(*scope, place);
ctx.Wait(); ctx.Wait();
memcpy(hccl_id, id, 1024); memcpy(hccl_id, id, 1024);
} }
void Prepare(f::Scope* scope, void Prepare(f::Scope* scope,
const p::DeviceContext& ctx, const p::DeviceContext& ctx,
HcclRootInfo* hccl_id) { HcclRootInfo* hccl_id) {
auto x = scope->Var("X"); auto x = scope->Var("X");
auto id = x->GetMutable<HcclRootInfo>(); auto id = x->GetMutable<HcclRootInfo>();
memcpy(id, hccl_id, 1024); memcpy(id, hccl_id, 1024);
int rank_id = atoi(getenv("RANK_ID")); int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID")); int device_id = atoi(getenv("DEVICE_ID"));
VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id
<< "; rank_id = " << rank_id << "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID")); << "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
// std::vector<int> rank_ids{0, 1}; // std::vector<int> rank_ids{0, 1};
f::AttributeMap comm_init_attrs; f::AttributeMap comm_init_attrs;
comm_init_attrs["ring_id"] = 0; comm_init_attrs["ring_id"] = 0;
comm_init_attrs["rank_ids"] = 2; comm_init_attrs["rank_ids"] = 2;
comm_init_attrs["rank"] = rank_id; comm_init_attrs["rank"] = rank_id;
comm_init_attrs["device_id"] = device_id; comm_init_attrs["device_id"] = device_id;
// comm_init_attrs["rank_ids"] = rank_ids; // comm_init_attrs["rank_ids"] = rank_ids;
auto comm_init_op = f::OpRegistry::CreateOp( auto comm_init_op = f::OpRegistry::CreateOp(
"c_comm_init_hccl", {{"X", {"X"}}}, {}, comm_init_attrs); "c_comm_init_hccl", {{"X", {"X"}}}, {}, comm_init_attrs);
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place); comm_init_op->Run(*scope, place);
ctx.Wait(); ctx.Wait();
} }
void TestHCCLAllGatherOp(f::Scope* scope, const p::DeviceContext& ctx) { void TestHCCLAllGatherOp(f::Scope* scope, const p::DeviceContext& ctx) {
// init // init
auto x = scope->Var("Data"); auto x = scope->Var("Data");
auto tensor_x = x->GetMutable<f::LoDTensor>(); auto tensor_x = x->GetMutable<f::LoDTensor>();
std::vector<float> init; std::vector<float> init;
int rank_id = atoi(getenv("RANK_ID")); int rank_id = atoi(getenv("RANK_ID"));
int num1 = 1; int num1 = 1;
int num2 = 4; int num2 = 4;
for (int64_t i = 0; i < num1 * num2; ++i) { for (int64_t i = 0; i < num1 * num2; ++i) {
init.push_back(1.0 + rank_id); init.push_back(1.0 + rank_id);
} }
PrintDebugInfo("input data", init); PrintDebugInfo("input data", init);
paddle::framework::TensorFromVector(init, ctx, tensor_x); paddle::framework::TensorFromVector(init, ctx, tensor_x);
tensor_x->Resize({num1, num2}); tensor_x->Resize({num1, num2});
ctx.Wait(); ctx.Wait();
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
auto out = scope->Var("OutData"); auto out = scope->Var("OutData");
auto tensor_out = out->GetMutable<f::LoDTensor>(); auto tensor_out = out->GetMutable<f::LoDTensor>();
tensor_out->Resize({num1, num2}); tensor_out->Resize({num1, num2});
tensor_out->mutable_data<float>(place); // allocate tensor_out->mutable_data<float>(place); // allocate
ctx.Wait(); ctx.Wait();
// run // run
f::AttributeMap attrs; f::AttributeMap attrs;
attrs["tag"] = std::string("tagx"); attrs["tag"] = std::string("tagx");
attrs["ring_id"] = 0; attrs["ring_id"] = 0;
attrs["nranks"] = 2; attrs["nranks"] = 2;
auto op = f::OpRegistry::CreateOp( auto op = f::OpRegistry::CreateOp(
"c_allgather", {{"X", {"Data"}}}, {{"Out", {"OutData"}}}, attrs); "c_allgather", {{"X", {"Data"}}}, {{"Out", {"OutData"}}}, attrs);
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
op->Run(*scope, place); op->Run(*scope, place);
} }
ctx.Wait(); ctx.Wait();
std::vector<float> out_vec; std::vector<float> out_vec;
paddle::framework::TensorToVector(*tensor_out, ctx, &out_vec); paddle::framework::TensorToVector(*tensor_out, ctx, &out_vec);
ctx.Wait(); ctx.Wait();
PrintDebugInfo("output data", out_vec); PrintDebugInfo("output data", out_vec);
EXPECT_EQ(out_vec.size(), init.size() * 2); EXPECT_EQ(out_vec.size(), init.size() * 2);
for (uint32_t i = 0; i < out_vec.size() / 2; i++) { for (uint32_t i = 0; i < out_vec.size() / 2; i++) {
EXPECT_EQ(out_vec[i], 1.0); EXPECT_EQ(out_vec[i], 1.0);
} }
for (uint32_t i = out_vec.size() / 2; i < out_vec.size(); i++) { for (uint32_t i = out_vec.size() / 2; i < out_vec.size(); i++) {
EXPECT_EQ(out_vec[i], 2.0); EXPECT_EQ(out_vec[i], 2.0);
} }
} }
TEST(c_allgather, NPU) { TEST(c_allgather, NPU) {
f::Scope scope; f::Scope scope;
HcclRootInfo hccl_id; HcclRootInfo hccl_id;
// only support one device, if more than one device, use first default // only support one device, if more than one device, use first default
p::NPUDeviceContext ctx(p::NPUPlace(atoi(FLAGS_selected_npus.c_str()))); p::NPUDeviceContext ctx(p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
PrepareUniqueId(&scope, ctx, &hccl_id); PrepareUniqueId(&scope, ctx, &hccl_id);
Prepare(&scope, ctx, &hccl_id); Prepare(&scope, ctx, &hccl_id);
TestHCCLAllGatherOp(&scope, ctx); TestHCCLAllGatherOp(&scope, ctx);
} }
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifndef _WIN32 #ifndef _WIN32
#include <unistd.h> #include <unistd.h>
#endif #endif
#include <stdio.h> #include <stdio.h>
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <vector> #include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/collective/c_allgather_op.h" #include "paddle/fluid/operators/collective/c_allgather_op.h"
#include "paddle/fluid/operators/collective/c_allreduce_op.h" #include "paddle/fluid/operators/collective/c_allreduce_op.h"
#include "paddle/fluid/operators/collective/c_broadcast_op.h" #include "paddle/fluid/operators/collective/c_broadcast_op.h"
#include "paddle/fluid/operators/collective/c_reducescatter_op.h" #include "paddle/fluid/operators/collective/c_reducescatter_op.h"
#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h" #include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#if defined(PADDLE_WITH_ASCEND_CL) #if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/npu/hccl_helper.h" #include "paddle/fluid/platform/device/npu/hccl_helper.h"
#endif #endif
namespace f = paddle::framework; namespace f = paddle::framework;
namespace p = paddle::platform; namespace p = paddle::platform;
USE_OP(c_allreduce_max); USE_OP(c_allreduce_max);
USE_NO_KERNEL_OP(c_gen_hccl_id); USE_NO_KERNEL_OP(c_gen_hccl_id);
USE_NO_KERNEL_OP(c_comm_init_hccl); USE_NO_KERNEL_OP(c_comm_init_hccl);
USE_OP_DEVICE_KERNEL(c_allreduce_max, NPU); USE_OP_DEVICE_KERNEL(c_allreduce_max, NPU);
DECLARE_string(selected_npus); DECLARE_string(selected_npus);
template <typename T> template <typename T>
void PrintDebugInfo(const std::string preStr, const std::vector<T>& data) { void PrintDebugInfo(const std::string preStr, const std::vector<T>& data) {
std::string debugstring = ""; std::string debugstring = "";
for (auto ele : data) { for (auto ele : data) {
debugstring += std::to_string(ele) + std::string(","); debugstring += std::to_string(ele) + std::string(",");
} }
VLOG(2) << preStr << ":" << std::endl << debugstring; VLOG(2) << preStr << ":" << std::endl << debugstring;
} }
void PrepareUniqueId(f::Scope* scope, void PrepareUniqueId(f::Scope* scope,
const p::DeviceContext& ctx, const p::DeviceContext& ctx,
HcclRootInfo* hccl_id) { HcclRootInfo* hccl_id) {
int rank_id = atoi(getenv("RANK_ID")); int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID")); int device_id = atoi(getenv("DEVICE_ID"));
VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id
<< "; rank_id = " << rank_id << "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID")); << "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
std::vector<int> rank_ids{0, 1}; std::vector<int> rank_ids{0, 1};
f::AttributeMap gen_hccl_id; f::AttributeMap gen_hccl_id;
std::vector<std::string> endpointList = {"127.0.0.1:6175", "127.0.0.1:6177"}; std::vector<std::string> endpointList = {"127.0.0.1:6175", "127.0.0.1:6177"};
gen_hccl_id["rank"] = rank_id; gen_hccl_id["rank"] = rank_id;
gen_hccl_id["endpoint"] = endpointList[rank_id]; gen_hccl_id["endpoint"] = endpointList[rank_id];
std::vector<std::string> other_endpoints = { std::vector<std::string> other_endpoints = {
endpointList[rank_id == 0 ? 1 : 0]}; endpointList[rank_id == 0 ? 1 : 0]};
gen_hccl_id["other_endpoints"] = other_endpoints; gen_hccl_id["other_endpoints"] = other_endpoints;
auto out = scope->Var("Out"); auto out = scope->Var("Out");
auto id = out->GetMutable<HcclRootInfo>(); auto id = out->GetMutable<HcclRootInfo>();
VLOG(3) << "break"; VLOG(3) << "break";
auto comm_init_op = f::OpRegistry::CreateOp( auto comm_init_op = f::OpRegistry::CreateOp(
"c_gen_hccl_id", {}, {{"Out", {"Out"}}}, gen_hccl_id); "c_gen_hccl_id", {}, {{"Out", {"Out"}}}, gen_hccl_id);
VLOG(3) << "break"; VLOG(3) << "break";
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place); comm_init_op->Run(*scope, place);
ctx.Wait(); ctx.Wait();
memcpy(hccl_id, id, 1024); memcpy(hccl_id, id, 1024);
} }
void Prepare(f::Scope* scope, void Prepare(f::Scope* scope,
const p::DeviceContext& ctx, const p::DeviceContext& ctx,
HcclRootInfo* hccl_id) { HcclRootInfo* hccl_id) {
auto x = scope->Var("X"); auto x = scope->Var("X");
auto id = x->GetMutable<HcclRootInfo>(); auto id = x->GetMutable<HcclRootInfo>();
memcpy(id, hccl_id, 1024); memcpy(id, hccl_id, 1024);
int rank_id = atoi(getenv("RANK_ID")); int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID")); int device_id = atoi(getenv("DEVICE_ID"));
VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id
<< "; rank_id = " << rank_id << "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID")); << "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
// std::vector<int> rank_ids{0, 1}; // std::vector<int> rank_ids{0, 1};
f::AttributeMap comm_init_attrs; f::AttributeMap comm_init_attrs;
comm_init_attrs["ring_id"] = 0; comm_init_attrs["ring_id"] = 0;
comm_init_attrs["rank_ids"] = 2; comm_init_attrs["rank_ids"] = 2;
comm_init_attrs["rank"] = rank_id; comm_init_attrs["rank"] = rank_id;
comm_init_attrs["device_id"] = device_id; comm_init_attrs["device_id"] = device_id;
// comm_init_attrs["rank_ids"] = rank_ids; // comm_init_attrs["rank_ids"] = rank_ids;
auto comm_init_op = f::OpRegistry::CreateOp( auto comm_init_op = f::OpRegistry::CreateOp(
"c_comm_init_hccl", {{"X", {"X"}}}, {}, comm_init_attrs); "c_comm_init_hccl", {{"X", {"X"}}}, {}, comm_init_attrs);
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place); comm_init_op->Run(*scope, place);
ctx.Wait(); ctx.Wait();
} }
void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx) { void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx) {
// init // init
auto x = scope->Var("Data"); auto x = scope->Var("Data");
auto tensor_x = x->GetMutable<f::LoDTensor>(); auto tensor_x = x->GetMutable<f::LoDTensor>();
std::vector<float> init; std::vector<float> init;
int rank_id = atoi(getenv("RANK_ID")); int rank_id = atoi(getenv("RANK_ID"));
int num1 = 100; int num1 = 100;
int num2 = 100; int num2 = 100;
for (int64_t i = 0; i < num1 * num2; ++i) { for (int64_t i = 0; i < num1 * num2; ++i) {
init.push_back(1.0 + rank_id * 3); init.push_back(1.0 + rank_id * 3);
} }
PrintDebugInfo("input data", init); PrintDebugInfo("input data", init);
paddle::framework::TensorFromVector(init, ctx, tensor_x); paddle::framework::TensorFromVector(init, ctx, tensor_x);
tensor_x->Resize({num1, num2}); tensor_x->Resize({num1, num2});
ctx.Wait(); ctx.Wait();
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
auto out = scope->Var("OutData"); auto out = scope->Var("OutData");
auto tensor_out = out->GetMutable<f::LoDTensor>(); auto tensor_out = out->GetMutable<f::LoDTensor>();
tensor_out->Resize({num1, num2}); tensor_out->Resize({num1, num2});
tensor_out->mutable_data<float>(place); // allocate tensor_out->mutable_data<float>(place); // allocate
ctx.Wait(); ctx.Wait();
// run // run
f::AttributeMap attrs; f::AttributeMap attrs;
attrs["tag"] = std::string("tagx"); attrs["tag"] = std::string("tagx");
attrs["ring_id"] = 0; attrs["ring_id"] = 0;
auto op = f::OpRegistry::CreateOp( auto op = f::OpRegistry::CreateOp(
"c_allreduce_max", {{"X", {"Data"}}}, {{"Out", {"OutData"}}}, attrs); "c_allreduce_max", {{"X", {"Data"}}}, {{"Out", {"OutData"}}}, attrs);
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
op->Run(*scope, place); op->Run(*scope, place);
} }
ctx.Wait(); ctx.Wait();
std::vector<float> out_vec; std::vector<float> out_vec;
paddle::framework::TensorToVector(*tensor_out, ctx, &out_vec); paddle::framework::TensorToVector(*tensor_out, ctx, &out_vec);
ctx.Wait(); ctx.Wait();
PrintDebugInfo("output data", out_vec); PrintDebugInfo("output data", out_vec);
EXPECT_EQ(out_vec.size(), init.size()); EXPECT_EQ(out_vec.size(), init.size());
for (uint32_t i = 0; i < out_vec.size(); i++) { for (uint32_t i = 0; i < out_vec.size(); i++) {
EXPECT_EQ(out_vec[i], 4.0); EXPECT_EQ(out_vec[i], 4.0);
} }
} }
TEST(c_allreduce_max, NPU) { TEST(c_allreduce_max, NPU) {
f::Scope scope; f::Scope scope;
HcclRootInfo hccl_id; HcclRootInfo hccl_id;
// only support one device, if more than one device, use first default // only support one device, if more than one device, use first default
p::NPUDeviceContext ctx(p::NPUPlace(atoi(FLAGS_selected_npus.c_str()))); p::NPUDeviceContext ctx(p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
PrepareUniqueId(&scope, ctx, &hccl_id); PrepareUniqueId(&scope, ctx, &hccl_id);
Prepare(&scope, ctx, &hccl_id); Prepare(&scope, ctx, &hccl_id);
TestHCCLAllReduceOp(&scope, ctx); TestHCCLAllReduceOp(&scope, ctx);
} }
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifndef _WIN32 #ifndef _WIN32
#include <unistd.h> #include <unistd.h>
#endif #endif
#include <stdio.h> #include <stdio.h>
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <vector> #include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/collective/c_allreduce_op.h" #include "paddle/fluid/operators/collective/c_allreduce_op.h"
#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h" #include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#if defined(PADDLE_WITH_ASCEND_CL) #if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/npu/hccl_helper.h" #include "paddle/fluid/platform/device/npu/hccl_helper.h"
#endif #endif
// Node1: HCCL_WHITELIST_DISABLE=1 FLAGS_selected_npus=1 GLOG_v=4 RANK_ID=1 // Node1: HCCL_WHITELIST_DISABLE=1 FLAGS_selected_npus=1 GLOG_v=4 RANK_ID=1
// DEVICE_ID=1 ./paddle/fluid/operators/collective/c_allreduce_sum_op_npu_test // DEVICE_ID=1 ./paddle/fluid/operators/collective/c_allreduce_sum_op_npu_test
// Node2: HCCL_WHITELIST_DISABLE=1 FLAGS_selected_npus=0 GLOG_v=4 RANK_ID=0 // Node2: HCCL_WHITELIST_DISABLE=1 FLAGS_selected_npus=0 GLOG_v=4 RANK_ID=0
// DEVICE_ID=0 ./paddle/fluid/operators/collective/c_allreduce_sum_op_npu_test // DEVICE_ID=0 ./paddle/fluid/operators/collective/c_allreduce_sum_op_npu_test
namespace f = paddle::framework; namespace f = paddle::framework;
namespace p = paddle::platform; namespace p = paddle::platform;
USE_OP(c_allreduce_sum); USE_OP(c_allreduce_sum);
USE_NO_KERNEL_OP(c_gen_hccl_id); USE_NO_KERNEL_OP(c_gen_hccl_id);
USE_NO_KERNEL_OP(c_comm_init_hccl); USE_NO_KERNEL_OP(c_comm_init_hccl);
USE_OP_DEVICE_KERNEL(c_allreduce_sum, NPU); USE_OP_DEVICE_KERNEL(c_allreduce_sum, NPU);
DECLARE_string(selected_npus); DECLARE_string(selected_npus);
template <typename T> template <typename T>
void PrintDebugInfo(const std::string preStr, const std::vector<T>& data) { void PrintDebugInfo(const std::string preStr, const std::vector<T>& data) {
std::string debugstring = ""; std::string debugstring = "";
std::cout << preStr << ":" << std::endl << debugstring; std::cout << preStr << ":" << std::endl << debugstring;
for (auto ele : data) { for (auto ele : data) {
std::cout << ele << " "; std::cout << ele << " ";
} }
std::cout << std::endl; std::cout << std::endl;
} }
void PrepareUniqueId(f::Scope* scope, void PrepareUniqueId(f::Scope* scope,
const p::DeviceContext& ctx, const p::DeviceContext& ctx,
HcclRootInfo* hccl_id) { HcclRootInfo* hccl_id) {
int rank_id = atoi(getenv("RANK_ID")); int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID")); int device_id = atoi(getenv("DEVICE_ID"));
VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id
<< "; rank_id = " << rank_id << "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID")); << "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
std::vector<int> rank_ids{0, 1}; std::vector<int> rank_ids{0, 1};
f::AttributeMap gen_hccl_id; f::AttributeMap gen_hccl_id;
std::vector<std::string> endpointList = {"127.0.0.1:6175", "127.0.0.1:6177"}; std::vector<std::string> endpointList = {"127.0.0.1:6175", "127.0.0.1:6177"};
gen_hccl_id["rank"] = rank_id; gen_hccl_id["rank"] = rank_id;
gen_hccl_id["endpoint"] = endpointList[rank_id]; gen_hccl_id["endpoint"] = endpointList[rank_id];
std::vector<std::string> other_endpoints = { std::vector<std::string> other_endpoints = {
endpointList[rank_id == 0 ? 1 : 0]}; endpointList[rank_id == 0 ? 1 : 0]};
gen_hccl_id["other_endpoints"] = other_endpoints; gen_hccl_id["other_endpoints"] = other_endpoints;
auto out = scope->Var("Out"); auto out = scope->Var("Out");
auto id = out->GetMutable<HcclRootInfo>(); auto id = out->GetMutable<HcclRootInfo>();
VLOG(3) << "break"; VLOG(3) << "break";
auto comm_init_op = f::OpRegistry::CreateOp( auto comm_init_op = f::OpRegistry::CreateOp(
"c_gen_hccl_id", {}, {{"Out", {"Out"}}}, gen_hccl_id); "c_gen_hccl_id", {}, {{"Out", {"Out"}}}, gen_hccl_id);
VLOG(3) << "break"; VLOG(3) << "break";
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place); comm_init_op->Run(*scope, place);
ctx.Wait(); ctx.Wait();
memcpy(hccl_id, id, 1024); memcpy(hccl_id, id, 1024);
} }
void Prepare(f::Scope* scope, void Prepare(f::Scope* scope,
const p::DeviceContext& ctx, const p::DeviceContext& ctx,
HcclRootInfo* hccl_id) { HcclRootInfo* hccl_id) {
auto x = scope->Var("X"); auto x = scope->Var("X");
auto id = x->GetMutable<HcclRootInfo>(); auto id = x->GetMutable<HcclRootInfo>();
memcpy(id, hccl_id, 1024); memcpy(id, hccl_id, 1024);
int rank_id = atoi(getenv("RANK_ID")); int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID")); int device_id = atoi(getenv("DEVICE_ID"));
VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id
<< "; rank_id = " << rank_id << "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID")); << "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
// std::vector<int> rank_ids{0, 1}; // std::vector<int> rank_ids{0, 1};
f::AttributeMap comm_init_attrs; f::AttributeMap comm_init_attrs;
comm_init_attrs["ring_id"] = 0; comm_init_attrs["ring_id"] = 0;
comm_init_attrs["rank_ids"] = 2; comm_init_attrs["rank_ids"] = 2;
comm_init_attrs["rank"] = rank_id; comm_init_attrs["rank"] = rank_id;
comm_init_attrs["device_id"] = device_id; comm_init_attrs["device_id"] = device_id;
// comm_init_attrs["rank_ids"] = rank_ids; // comm_init_attrs["rank_ids"] = rank_ids;
auto comm_init_op = f::OpRegistry::CreateOp( auto comm_init_op = f::OpRegistry::CreateOp(
"c_comm_init_hccl", {{"X", {"X"}}}, {}, comm_init_attrs); "c_comm_init_hccl", {{"X", {"X"}}}, {}, comm_init_attrs);
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place); comm_init_op->Run(*scope, place);
ctx.Wait(); ctx.Wait();
} }
template <typename T> template <typename T>
void TestHCCLAllReduceOp(f::Scope* scope, void TestHCCLAllReduceOp(f::Scope* scope,
const p::DeviceContext& ctx, const p::DeviceContext& ctx,
int iter) { int iter) {
// init // init
auto x = scope->Var("Data"); auto x = scope->Var("Data");
auto tensor_x = x->GetMutable<f::LoDTensor>(); auto tensor_x = x->GetMutable<f::LoDTensor>();
int rank_id = atoi(getenv("RANK_ID")); int rank_id = atoi(getenv("RANK_ID"));
int num1 = 3; int num1 = 3;
int num2 = 128; int num2 = 128;
std::vector<T> init; std::vector<T> init;
for (int64_t i = 0; i < num1 * num2; ++i) { for (int64_t i = 0; i < num1 * num2; ++i) {
init.push_back(static_cast<T>(1.0 + rank_id)); init.push_back(static_cast<T>(1.0 + rank_id));
} }
init[0] = static_cast<T>(std::numeric_limits<float>::quiet_NaN()); init[0] = static_cast<T>(std::numeric_limits<float>::quiet_NaN());
PrintDebugInfo("input data", init); PrintDebugInfo("input data", init);
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
paddle::framework::TensorFromVector(init, ctx, tensor_x); paddle::framework::TensorFromVector(init, ctx, tensor_x);
tensor_x->Resize({num1, num2}); tensor_x->Resize({num1, num2});
ctx.Wait(); ctx.Wait();
auto out = scope->Var("OutData"); auto out = scope->Var("OutData");
auto tensor_out = out->GetMutable<f::LoDTensor>(); auto tensor_out = out->GetMutable<f::LoDTensor>();
tensor_out->Resize({num1, num2}); tensor_out->Resize({num1, num2});
tensor_out->mutable_data<T>(place); // allocate tensor_out->mutable_data<T>(place); // allocate
ctx.Wait(); ctx.Wait();
// run // run
f::AttributeMap attrs; f::AttributeMap attrs;
attrs["tag"] = std::string("tagx_" + std::to_string(iter)); attrs["tag"] = std::string("tagx_" + std::to_string(iter));
attrs["ring_id"] = 0; attrs["ring_id"] = 0;
attrs["use_calc_stream"] = 1; attrs["use_calc_stream"] = 1;
auto op = f::OpRegistry::CreateOp( auto op = f::OpRegistry::CreateOp(
"c_allreduce_sum", {{"X", {"Data"}}}, {{"Out", {"OutData"}}}, attrs); "c_allreduce_sum", {{"X", {"Data"}}}, {{"Out", {"OutData"}}}, attrs);
for (int i = 0; i < 1; i++) { for (int i = 0; i < 1; i++) {
op->Run(*scope, place); op->Run(*scope, place);
} }
ctx.Wait(); ctx.Wait();
std::vector<T> out_vec; std::vector<T> out_vec;
paddle::framework::TensorToVector(*tensor_out, ctx, &out_vec); paddle::framework::TensorToVector(*tensor_out, ctx, &out_vec);
ctx.Wait(); ctx.Wait();
PrintDebugInfo("output data", out_vec); PrintDebugInfo("output data", out_vec);
float diff = static_cast<float>(out_vec[0]) - 65504; float diff = static_cast<float>(out_vec[0]) - 65504;
EXPECT_TRUE(diff < 0.1 && diff > -0.1); EXPECT_TRUE(diff < 0.1 && diff > -0.1);
EXPECT_EQ(out_vec.size(), init.size()); EXPECT_EQ(out_vec.size(), init.size());
for (uint32_t i = 1; i < 10; i++) { for (uint32_t i = 1; i < 10; i++) {
EXPECT_EQ(out_vec[i], static_cast<paddle::platform::float16>(3.0)); EXPECT_EQ(out_vec[i], static_cast<paddle::platform::float16>(3.0));
} }
} }
TEST(c_allreduce_sum, NPU) { TEST(c_allreduce_sum, NPU) {
f::Scope scope; f::Scope scope;
HcclRootInfo hccl_id; HcclRootInfo hccl_id;
p::NPUDeviceContext ctx(p::NPUPlace(atoi(FLAGS_selected_npus.c_str()))); p::NPUDeviceContext ctx(p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
// only support one device, if more than one device, use first default // only support one device, if more than one device, use first default
PrepareUniqueId(&scope, ctx, &hccl_id); PrepareUniqueId(&scope, ctx, &hccl_id);
Prepare(&scope, ctx, &hccl_id); Prepare(&scope, ctx, &hccl_id);
TestHCCLAllReduceOp<paddle::platform::float16>(&scope, ctx, 1); TestHCCLAllReduceOp<paddle::platform::float16>(&scope, ctx, 1);
// TestHCCLAllReduceOp<float>(&scope, ctx, 0); // TestHCCLAllReduceOp<float>(&scope, ctx, 0);
} }
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifndef _WIN32 #ifndef _WIN32
#include <unistd.h> #include <unistd.h>
#endif #endif
#include <stdio.h> #include <stdio.h>
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <vector> #include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/collective/c_broadcast_op.h" #include "paddle/fluid/operators/collective/c_broadcast_op.h"
#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h" #include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#if defined(PADDLE_WITH_ASCEND_CL) #if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/npu/hccl_helper.h" #include "paddle/fluid/platform/device/npu/hccl_helper.h"
#endif #endif
namespace f = paddle::framework; namespace f = paddle::framework;
namespace p = paddle::platform; namespace p = paddle::platform;
USE_OP(c_broadcast); USE_OP(c_broadcast);
USE_NO_KERNEL_OP(c_gen_hccl_id); USE_NO_KERNEL_OP(c_gen_hccl_id);
USE_NO_KERNEL_OP(c_comm_init_hccl); USE_NO_KERNEL_OP(c_comm_init_hccl);
USE_OP_DEVICE_KERNEL(c_broadcast, NPU); USE_OP_DEVICE_KERNEL(c_broadcast, NPU);
DECLARE_string(selected_npus); DECLARE_string(selected_npus);
template <typename T> template <typename T>
void PrintDebugInfo(const std::string preStr, const std::vector<T>& data) { void PrintDebugInfo(const std::string preStr, const std::vector<T>& data) {
std::string debugstring = ""; std::string debugstring = "";
for (auto ele : data) { for (auto ele : data) {
debugstring += std::to_string(ele) + std::string(","); debugstring += std::to_string(ele) + std::string(",");
} }
VLOG(2) << preStr << ":" << std::endl << debugstring; VLOG(2) << preStr << ":" << std::endl << debugstring;
} }
void PrepareUniqueId(f::Scope* scope, void PrepareUniqueId(f::Scope* scope,
const p::DeviceContext& ctx, const p::DeviceContext& ctx,
HcclRootInfo* hccl_id) { HcclRootInfo* hccl_id) {
int rank_id = atoi(getenv("RANK_ID")); int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID")); int device_id = atoi(getenv("DEVICE_ID"));
VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id
<< "; rank_id = " << rank_id << "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID")); << "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
std::vector<int> rank_ids{0, 1}; std::vector<int> rank_ids{0, 1};
f::AttributeMap gen_hccl_id; f::AttributeMap gen_hccl_id;
std::vector<std::string> endpointList = {"127.0.0.1:6175", "127.0.0.1:6177"}; std::vector<std::string> endpointList = {"127.0.0.1:6175", "127.0.0.1:6177"};
gen_hccl_id["rank"] = rank_id; gen_hccl_id["rank"] = rank_id;
gen_hccl_id["endpoint"] = endpointList[rank_id]; gen_hccl_id["endpoint"] = endpointList[rank_id];
std::vector<std::string> other_endpoints = { std::vector<std::string> other_endpoints = {
endpointList[rank_id == 0 ? 1 : 0]}; endpointList[rank_id == 0 ? 1 : 0]};
gen_hccl_id["other_endpoints"] = other_endpoints; gen_hccl_id["other_endpoints"] = other_endpoints;
auto out = scope->Var("Out"); auto out = scope->Var("Out");
auto id = out->GetMutable<HcclRootInfo>(); auto id = out->GetMutable<HcclRootInfo>();
VLOG(3) << "break"; VLOG(3) << "break";
auto comm_init_op = f::OpRegistry::CreateOp( auto comm_init_op = f::OpRegistry::CreateOp(
"c_gen_hccl_id", {}, {{"Out", {"Out"}}}, gen_hccl_id); "c_gen_hccl_id", {}, {{"Out", {"Out"}}}, gen_hccl_id);
VLOG(3) << "break"; VLOG(3) << "break";
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place); comm_init_op->Run(*scope, place);
ctx.Wait(); ctx.Wait();
memcpy(hccl_id, id, 1024); memcpy(hccl_id, id, 1024);
} }
void Prepare(f::Scope* scope, void Prepare(f::Scope* scope,
const p::DeviceContext& ctx, const p::DeviceContext& ctx,
HcclRootInfo* hccl_id) { HcclRootInfo* hccl_id) {
auto x = scope->Var("X"); auto x = scope->Var("X");
auto id = x->GetMutable<HcclRootInfo>(); auto id = x->GetMutable<HcclRootInfo>();
memcpy(id, hccl_id, 1024); memcpy(id, hccl_id, 1024);
int rank_id = atoi(getenv("RANK_ID")); int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID")); int device_id = atoi(getenv("DEVICE_ID"));
VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id
<< "; rank_id = " << rank_id << "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID")); << "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
// std::vector<int> rank_ids{0, 1}; // std::vector<int> rank_ids{0, 1};
f::AttributeMap comm_init_attrs; f::AttributeMap comm_init_attrs;
comm_init_attrs["ring_id"] = 0; comm_init_attrs["ring_id"] = 0;
comm_init_attrs["rank_ids"] = 2; comm_init_attrs["rank_ids"] = 2;
comm_init_attrs["rank"] = rank_id; comm_init_attrs["rank"] = rank_id;
comm_init_attrs["device_id"] = device_id; comm_init_attrs["device_id"] = device_id;
// comm_init_attrs["rank_ids"] = rank_ids; // comm_init_attrs["rank_ids"] = rank_ids;
auto comm_init_op = f::OpRegistry::CreateOp( auto comm_init_op = f::OpRegistry::CreateOp(
"c_comm_init_hccl", {{"X", {"X"}}}, {}, comm_init_attrs); "c_comm_init_hccl", {{"X", {"X"}}}, {}, comm_init_attrs);
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place); comm_init_op->Run(*scope, place);
ctx.Wait(); ctx.Wait();
} }
void TestHCCLBroadcastOp(f::Scope* scope, const p::DeviceContext& ctx) { void TestHCCLBroadcastOp(f::Scope* scope, const p::DeviceContext& ctx) {
// init // init
auto x = scope->Var("Data"); auto x = scope->Var("Data");
auto tensor_x = x->GetMutable<f::LoDTensor>(); auto tensor_x = x->GetMutable<f::LoDTensor>();
int num = 2; int num = 2;
std::vector<float> init; std::vector<float> init;
int rank_id = atoi(getenv("RANK_ID")); int rank_id = atoi(getenv("RANK_ID"));
for (int64_t i = 0; i < num * num; ++i) { for (int64_t i = 0; i < num * num; ++i) {
init.push_back(1.0 + rank_id); init.push_back(1.0 + rank_id);
} }
PrintDebugInfo("input data", init); PrintDebugInfo("input data", init);
paddle::framework::TensorFromVector(init, ctx, tensor_x); paddle::framework::TensorFromVector(init, ctx, tensor_x);
tensor_x->Resize({num, num}); tensor_x->Resize({num, num});
ctx.Wait(); ctx.Wait();
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
auto out = scope->Var("OutData"); auto out = scope->Var("OutData");
auto tensor_out = out->GetMutable<f::LoDTensor>(); auto tensor_out = out->GetMutable<f::LoDTensor>();
tensor_out->Resize({num, num}); tensor_out->Resize({num, num});
tensor_out->mutable_data<float>(place); // allocate tensor_out->mutable_data<float>(place); // allocate
ctx.Wait(); ctx.Wait();
// run // run
f::AttributeMap attrs; f::AttributeMap attrs;
attrs["tag"] = std::string("tagx"); attrs["tag"] = std::string("tagx");
attrs["root"] = 0; attrs["root"] = 0;
attrs["ring_id"] = 0; attrs["ring_id"] = 0;
auto op = f::OpRegistry::CreateOp( auto op = f::OpRegistry::CreateOp(
"c_broadcast", {{"X", {"Data"}}}, {{"Out", {"OutData"}}}, attrs); "c_broadcast", {{"X", {"Data"}}}, {{"Out", {"OutData"}}}, attrs);
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
op->Run(*scope, place); op->Run(*scope, place);
} }
ctx.Wait(); ctx.Wait();
std::vector<float> out_vec; std::vector<float> out_vec;
paddle::framework::TensorToVector(*tensor_out, ctx, &out_vec); paddle::framework::TensorToVector(*tensor_out, ctx, &out_vec);
ctx.Wait(); ctx.Wait();
PrintDebugInfo("output data", out_vec); PrintDebugInfo("output data", out_vec);
EXPECT_EQ(out_vec.size(), init.size()); EXPECT_EQ(out_vec.size(), init.size());
for (uint32_t i = 0; i < out_vec.size(); i++) { for (uint32_t i = 0; i < out_vec.size(); i++) {
EXPECT_EQ(out_vec[i], 1.0); EXPECT_EQ(out_vec[i], 1.0);
} }
} }
TEST(c_broadcast, NPU) { TEST(c_broadcast, NPU) {
f::Scope scope; f::Scope scope;
HcclRootInfo hccl_id; HcclRootInfo hccl_id;
// only support one device, if more than one device, use first default // only support one device, if more than one device, use first default
p::NPUDeviceContext ctx(p::NPUPlace(atoi(FLAGS_selected_npus.c_str()))); p::NPUDeviceContext ctx(p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
PrepareUniqueId(&scope, ctx, &hccl_id); PrepareUniqueId(&scope, ctx, &hccl_id);
Prepare(&scope, ctx, &hccl_id); Prepare(&scope, ctx, &hccl_id);
TestHCCLBroadcastOp(&scope, ctx); TestHCCLBroadcastOp(&scope, ctx);
} }
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifndef _WIN32 #ifndef _WIN32
#include <unistd.h> #include <unistd.h>
#endif #endif
#include <stdio.h> #include <stdio.h>
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <vector> #include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/collective/c_reduce_op.h" #include "paddle/fluid/operators/collective/c_reduce_op.h"
#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h" #include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#if defined(PADDLE_WITH_ASCEND_CL) #if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/npu/hccl_helper.h" #include "paddle/fluid/platform/device/npu/hccl_helper.h"
#endif #endif
namespace f = paddle::framework; namespace f = paddle::framework;
namespace p = paddle::platform; namespace p = paddle::platform;
USE_OP(c_reduce_sum); USE_OP(c_reduce_sum);
USE_NO_KERNEL_OP(c_gen_hccl_id); USE_NO_KERNEL_OP(c_gen_hccl_id);
USE_NO_KERNEL_OP(c_comm_init_hccl); USE_NO_KERNEL_OP(c_comm_init_hccl);
USE_OP_DEVICE_KERNEL(c_reduce_sum, NPU); USE_OP_DEVICE_KERNEL(c_reduce_sum, NPU);
DECLARE_string(selected_npus); DECLARE_string(selected_npus);
template <typename T> template <typename T>
void PrintDebugInfo(const std::string preStr, const std::vector<T>& data) { void PrintDebugInfo(const std::string preStr, const std::vector<T>& data) {
std::string debugstring = ""; std::string debugstring = "";
for (auto ele : data) { for (auto ele : data) {
debugstring += std::to_string(ele) + std::string(","); debugstring += std::to_string(ele) + std::string(",");
} }
VLOG(3) << preStr << ":" << std::endl << debugstring; VLOG(3) << preStr << ":" << std::endl << debugstring;
} }
void PrepareUniqueId(f::Scope* scope, void PrepareUniqueId(f::Scope* scope,
const p::DeviceContext& ctx, const p::DeviceContext& ctx,
HcclRootInfo* hccl_id) { HcclRootInfo* hccl_id) {
int rank_id = atoi(getenv("RANK_ID")); int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID")); int device_id = atoi(getenv("DEVICE_ID"));
VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id
<< "; rank_id = " << rank_id << "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID")); << "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
std::vector<int> rank_ids{0, 1}; std::vector<int> rank_ids{0, 1};
f::AttributeMap gen_hccl_id; f::AttributeMap gen_hccl_id;
std::vector<std::string> endpointList = {"127.0.0.1:6175", "127.0.0.1:6177"}; std::vector<std::string> endpointList = {"127.0.0.1:6175", "127.0.0.1:6177"};
gen_hccl_id["rank"] = rank_id; gen_hccl_id["rank"] = rank_id;
gen_hccl_id["endpoint"] = endpointList[rank_id]; gen_hccl_id["endpoint"] = endpointList[rank_id];
std::vector<std::string> other_endpoints = { std::vector<std::string> other_endpoints = {
endpointList[rank_id == 0 ? 1 : 0]}; endpointList[rank_id == 0 ? 1 : 0]};
gen_hccl_id["other_endpoints"] = other_endpoints; gen_hccl_id["other_endpoints"] = other_endpoints;
auto out = scope->Var("Out"); auto out = scope->Var("Out");
auto id = out->GetMutable<HcclRootInfo>(); auto id = out->GetMutable<HcclRootInfo>();
VLOG(3) << "break"; VLOG(3) << "break";
auto comm_init_op = f::OpRegistry::CreateOp( auto comm_init_op = f::OpRegistry::CreateOp(
"c_gen_hccl_id", {}, {{"Out", {"Out"}}}, gen_hccl_id); "c_gen_hccl_id", {}, {{"Out", {"Out"}}}, gen_hccl_id);
VLOG(3) << "break"; VLOG(3) << "break";
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place); comm_init_op->Run(*scope, place);
ctx.Wait(); ctx.Wait();
memcpy(hccl_id, id, 1024); memcpy(hccl_id, id, 1024);
} }
void Prepare(f::Scope* scope, void Prepare(f::Scope* scope,
const p::DeviceContext& ctx, const p::DeviceContext& ctx,
HcclRootInfo* hccl_id) { HcclRootInfo* hccl_id) {
auto x = scope->Var("X"); auto x = scope->Var("X");
auto id = x->GetMutable<HcclRootInfo>(); auto id = x->GetMutable<HcclRootInfo>();
memcpy(id, hccl_id, 1024); memcpy(id, hccl_id, 1024);
int rank_id = atoi(getenv("RANK_ID")); int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID")); int device_id = atoi(getenv("DEVICE_ID"));
VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id
<< "; rank_id = " << rank_id << "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID")); << "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
// std::vector<int> rank_ids{0, 1}; // std::vector<int> rank_ids{0, 1};
f::AttributeMap comm_init_attrs; f::AttributeMap comm_init_attrs;
comm_init_attrs["ring_id"] = 0; comm_init_attrs["ring_id"] = 0;
comm_init_attrs["rank_ids"] = 2; comm_init_attrs["rank_ids"] = 2;
comm_init_attrs["rank"] = rank_id; comm_init_attrs["rank"] = rank_id;
comm_init_attrs["device_id"] = device_id; comm_init_attrs["device_id"] = device_id;
// comm_init_attrs["rank_ids"] = rank_ids; // comm_init_attrs["rank_ids"] = rank_ids;
auto comm_init_op = f::OpRegistry::CreateOp( auto comm_init_op = f::OpRegistry::CreateOp(
"c_comm_init_hccl", {{"X", {"X"}}}, {}, comm_init_attrs); "c_comm_init_hccl", {{"X", {"X"}}}, {}, comm_init_attrs);
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place); comm_init_op->Run(*scope, place);
ctx.Wait(); ctx.Wait();
} }
void TestHCCLReduceOp(f::Scope* scope, const p::DeviceContext& ctx, int iter) { void TestHCCLReduceOp(f::Scope* scope, const p::DeviceContext& ctx, int iter) {
// init // init
auto x = scope->Var("Data"); auto x = scope->Var("Data");
auto tensor_x = x->GetMutable<f::LoDTensor>(); auto tensor_x = x->GetMutable<f::LoDTensor>();
int rank_id = atoi(getenv("RANK_ID")); int rank_id = atoi(getenv("RANK_ID"));
int num1 = 3; int num1 = 3;
int num2 = 128; int num2 = 128;
std::vector<float> init; std::vector<float> init;
for (int64_t i = 0; i < num1 * num2; ++i) { for (int64_t i = 0; i < num1 * num2; ++i) {
init.push_back(1.0 + rank_id); init.push_back(1.0 + rank_id);
} }
PrintDebugInfo("input data", init); PrintDebugInfo("input data", init);
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
paddle::framework::TensorFromVector(init, ctx, tensor_x); paddle::framework::TensorFromVector(init, ctx, tensor_x);
tensor_x->Resize({num1, num2}); tensor_x->Resize({num1, num2});
ctx.Wait(); ctx.Wait();
auto out = scope->Var("OutData"); auto out = scope->Var("OutData");
auto tensor_out = out->GetMutable<f::LoDTensor>(); auto tensor_out = out->GetMutable<f::LoDTensor>();
tensor_out->Resize({num1, num2}); tensor_out->Resize({num1, num2});
tensor_out->mutable_data<float>(place); // allocate tensor_out->mutable_data<float>(place); // allocate
ctx.Wait(); ctx.Wait();
// run // run
f::AttributeMap attrs; f::AttributeMap attrs;
attrs["tag"] = std::string("tagx_" + std::to_string(iter)); attrs["tag"] = std::string("tagx_" + std::to_string(iter));
attrs["ring_id"] = 0; attrs["ring_id"] = 0;
int root_id = 0; int root_id = 0;
attrs["root_id"] = root_id; attrs["root_id"] = root_id;
auto op = f::OpRegistry::CreateOp( auto op = f::OpRegistry::CreateOp(
"c_reduce_sum", {{"X", {"Data"}}}, {{"Out", {"OutData"}}}, attrs); "c_reduce_sum", {{"X", {"Data"}}}, {{"Out", {"OutData"}}}, attrs);
op->Run(*scope, place); op->Run(*scope, place);
ctx.Wait(); ctx.Wait();
std::vector<float> out_vec; std::vector<float> out_vec;
paddle::framework::TensorToVector(*tensor_out, ctx, &out_vec); paddle::framework::TensorToVector(*tensor_out, ctx, &out_vec);
ctx.Wait(); ctx.Wait();
PrintDebugInfo("output data", out_vec); PrintDebugInfo("output data", out_vec);
EXPECT_EQ(out_vec.size(), init.size()); EXPECT_EQ(out_vec.size(), init.size());
for (uint32_t i = 0; i < out_vec.size(); i++) { for (uint32_t i = 0; i < out_vec.size(); i++) {
if (rank_id == root_id) { if (rank_id == root_id) {
EXPECT_EQ(out_vec[i], 3.0); EXPECT_EQ(out_vec[i], 3.0);
} else { } else {
EXPECT_EQ(out_vec[i], init[i]); EXPECT_EQ(out_vec[i], init[i]);
} }
} }
} }
TEST(c_reduce_sum, NPU) { TEST(c_reduce_sum, NPU) {
f::Scope scope; f::Scope scope;
HcclRootInfo hccl_id; HcclRootInfo hccl_id;
// only support one device, if more than one device, use first default // only support one device, if more than one device, use first default
p::NPUDeviceContext ctx(p::NPUPlace(atoi(FLAGS_selected_npus.c_str()))); p::NPUDeviceContext ctx(p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
PrepareUniqueId(&scope, ctx, &hccl_id); PrepareUniqueId(&scope, ctx, &hccl_id);
Prepare(&scope, ctx, &hccl_id); Prepare(&scope, ctx, &hccl_id);
for (int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
VLOG(2) << "iter num: " << i; VLOG(2) << "iter num: " << i;
TestHCCLReduceOp(&scope, ctx, i); TestHCCLReduceOp(&scope, ctx, i);
} }
} }
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifndef _WIN32 #ifndef _WIN32
#include <unistd.h> #include <unistd.h>
#endif #endif
#include <stdio.h> #include <stdio.h>
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <vector> #include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/collective/c_allgather_op.h" #include "paddle/fluid/operators/collective/c_allgather_op.h"
#include "paddle/fluid/operators/collective/c_allreduce_op.h" #include "paddle/fluid/operators/collective/c_allreduce_op.h"
#include "paddle/fluid/operators/collective/c_broadcast_op.h" #include "paddle/fluid/operators/collective/c_broadcast_op.h"
#include "paddle/fluid/operators/collective/c_reducescatter_op.h" #include "paddle/fluid/operators/collective/c_reducescatter_op.h"
#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h" #include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#if defined(PADDLE_WITH_ASCEND_CL) #if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/npu/hccl_helper.h" #include "paddle/fluid/platform/device/npu/hccl_helper.h"
#endif #endif
namespace f = paddle::framework; namespace f = paddle::framework;
namespace p = paddle::platform; namespace p = paddle::platform;
USE_OP(c_reducescatter); USE_OP(c_reducescatter);
USE_NO_KERNEL_OP(c_gen_hccl_id); USE_NO_KERNEL_OP(c_gen_hccl_id);
USE_NO_KERNEL_OP(c_comm_init_hccl); USE_NO_KERNEL_OP(c_comm_init_hccl);
USE_OP_DEVICE_KERNEL(c_reducescatter, NPU); USE_OP_DEVICE_KERNEL(c_reducescatter, NPU);
DECLARE_string(selected_npus); DECLARE_string(selected_npus);
template <typename T> template <typename T>
void PrintDebugInfo(const std::string preStr, const std::vector<T>& data) { void PrintDebugInfo(const std::string preStr, const std::vector<T>& data) {
std::string debugstring = ""; std::string debugstring = "";
for (auto ele : data) { for (auto ele : data) {
debugstring += std::to_string(ele) + std::string(","); debugstring += std::to_string(ele) + std::string(",");
} }
VLOG(2) << preStr << ":" << std::endl << debugstring; VLOG(2) << preStr << ":" << std::endl << debugstring;
} }
void PrepareUniqueId(f::Scope* scope, void PrepareUniqueId(f::Scope* scope,
const p::DeviceContext& ctx, const p::DeviceContext& ctx,
HcclRootInfo* hccl_id) { HcclRootInfo* hccl_id) {
int rank_id = atoi(getenv("RANK_ID")); int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID")); int device_id = atoi(getenv("DEVICE_ID"));
VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id
<< "; rank_id = " << rank_id << "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID")); << "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
std::vector<int> rank_ids{0, 1}; std::vector<int> rank_ids{0, 1};
f::AttributeMap gen_hccl_id; f::AttributeMap gen_hccl_id;
std::vector<std::string> endpointList = {"127.0.0.1:6175", "127.0.0.1:6177"}; std::vector<std::string> endpointList = {"127.0.0.1:6175", "127.0.0.1:6177"};
gen_hccl_id["rank"] = rank_id; gen_hccl_id["rank"] = rank_id;
gen_hccl_id["endpoint"] = endpointList[rank_id]; gen_hccl_id["endpoint"] = endpointList[rank_id];
std::vector<std::string> other_endpoints = { std::vector<std::string> other_endpoints = {
endpointList[rank_id == 0 ? 1 : 0]}; endpointList[rank_id == 0 ? 1 : 0]};
gen_hccl_id["other_endpoints"] = other_endpoints; gen_hccl_id["other_endpoints"] = other_endpoints;
auto out = scope->Var("Out"); auto out = scope->Var("Out");
auto id = out->GetMutable<HcclRootInfo>(); auto id = out->GetMutable<HcclRootInfo>();
VLOG(3) << "break"; VLOG(3) << "break";
auto comm_init_op = f::OpRegistry::CreateOp( auto comm_init_op = f::OpRegistry::CreateOp(
"c_gen_hccl_id", {}, {{"Out", {"Out"}}}, gen_hccl_id); "c_gen_hccl_id", {}, {{"Out", {"Out"}}}, gen_hccl_id);
VLOG(3) << "break"; VLOG(3) << "break";
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place); comm_init_op->Run(*scope, place);
ctx.Wait(); ctx.Wait();
memcpy(hccl_id, id, 1024); memcpy(hccl_id, id, 1024);
} }
void Prepare(f::Scope* scope, void Prepare(f::Scope* scope,
const p::DeviceContext& ctx, const p::DeviceContext& ctx,
HcclRootInfo* hccl_id) { HcclRootInfo* hccl_id) {
auto x = scope->Var("X"); auto x = scope->Var("X");
auto id = x->GetMutable<HcclRootInfo>(); auto id = x->GetMutable<HcclRootInfo>();
memcpy(id, hccl_id, 1024); memcpy(id, hccl_id, 1024);
int rank_id = atoi(getenv("RANK_ID")); int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID")); int device_id = atoi(getenv("DEVICE_ID"));
VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id
<< "; rank_id = " << rank_id << "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID")); << "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
// std::vector<int> rank_ids{0, 1}; // std::vector<int> rank_ids{0, 1};
f::AttributeMap comm_init_attrs; f::AttributeMap comm_init_attrs;
comm_init_attrs["ring_id"] = 0; comm_init_attrs["ring_id"] = 0;
comm_init_attrs["rank_ids"] = 2; comm_init_attrs["rank_ids"] = 2;
comm_init_attrs["rank"] = rank_id; comm_init_attrs["rank"] = rank_id;
comm_init_attrs["device_id"] = device_id; comm_init_attrs["device_id"] = device_id;
// comm_init_attrs["rank_ids"] = rank_ids; // comm_init_attrs["rank_ids"] = rank_ids;
auto comm_init_op = f::OpRegistry::CreateOp( auto comm_init_op = f::OpRegistry::CreateOp(
"c_comm_init_hccl", {{"X", {"X"}}}, {}, comm_init_attrs); "c_comm_init_hccl", {{"X", {"X"}}}, {}, comm_init_attrs);
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place); comm_init_op->Run(*scope, place);
ctx.Wait(); ctx.Wait();
} }
void TestHCCLReduceScatterOp(f::Scope* scope, const p::DeviceContext& ctx) { void TestHCCLReduceScatterOp(f::Scope* scope, const p::DeviceContext& ctx) {
// init // init
auto x = scope->Var("Data"); auto x = scope->Var("Data");
auto tensor_x = x->GetMutable<f::LoDTensor>(); auto tensor_x = x->GetMutable<f::LoDTensor>();
std::vector<float> init; std::vector<float> init;
int num1 = 4; int num1 = 4;
int num2 = 1; int num2 = 1;
for (int64_t i = 0; i < num1 * num2; ++i) { for (int64_t i = 0; i < num1 * num2; ++i) {
init.push_back(1.0); init.push_back(1.0);
} }
PrintDebugInfo("input data", init); PrintDebugInfo("input data", init);
paddle::framework::TensorFromVector(init, ctx, tensor_x); paddle::framework::TensorFromVector(init, ctx, tensor_x);
tensor_x->Resize({num1, num2}); tensor_x->Resize({num1, num2});
ctx.Wait(); ctx.Wait();
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
auto out = scope->Var("OutData"); auto out = scope->Var("OutData");
auto tensor_out = out->GetMutable<f::LoDTensor>(); auto tensor_out = out->GetMutable<f::LoDTensor>();
tensor_out->Resize({num1, num2}); tensor_out->Resize({num1, num2});
tensor_out->mutable_data<float>(place); // allocate tensor_out->mutable_data<float>(place); // allocate
ctx.Wait(); ctx.Wait();
// run // run
f::AttributeMap attrs; f::AttributeMap attrs;
attrs["tag"] = std::string("tagx"); attrs["tag"] = std::string("tagx");
attrs["ring_id"] = 0; attrs["ring_id"] = 0;
attrs["nranks"] = 2; attrs["nranks"] = 2;
auto op = f::OpRegistry::CreateOp( auto op = f::OpRegistry::CreateOp(
"c_reducescatter", {{"X", {"Data"}}}, {{"Out", {"OutData"}}}, attrs); "c_reducescatter", {{"X", {"Data"}}}, {{"Out", {"OutData"}}}, attrs);
int iter_num = 10; int iter_num = 10;
for (int i = 0; i < iter_num; i++) { for (int i = 0; i < iter_num; i++) {
op->Run(*scope, place); op->Run(*scope, place);
ctx.Wait(); ctx.Wait();
} }
std::vector<float> out_vec; std::vector<float> out_vec;
paddle::framework::TensorToVector(*tensor_out, ctx, &out_vec); paddle::framework::TensorToVector(*tensor_out, ctx, &out_vec);
ctx.Wait(); ctx.Wait();
PrintDebugInfo("output data", out_vec); PrintDebugInfo("output data", out_vec);
EXPECT_EQ(out_vec.size(), init.size() / 2); EXPECT_EQ(out_vec.size(), init.size() / 2);
for (uint32_t i = 0; i < out_vec.size(); i++) { for (uint32_t i = 0; i < out_vec.size(); i++) {
EXPECT_EQ(out_vec[i], 2.0); EXPECT_EQ(out_vec[i], 2.0);
} }
} }
TEST(c_reducescatter, NPU) { TEST(c_reducescatter, NPU) {
f::Scope scope; f::Scope scope;
HcclRootInfo hccl_id; HcclRootInfo hccl_id;
// only support one device, if more than one device, use first default // only support one device, if more than one device, use first default
p::NPUDeviceContext ctx(p::NPUPlace(atoi(FLAGS_selected_npus.c_str()))); p::NPUDeviceContext ctx(p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
PrepareUniqueId(&scope, ctx, &hccl_id); PrepareUniqueId(&scope, ctx, &hccl_id);
Prepare(&scope, ctx, &hccl_id); Prepare(&scope, ctx, &hccl_id);
TestHCCLReduceScatterOp(&scope, ctx); TestHCCLReduceScatterOp(&scope, ctx);
} }
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/utils.h" #include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/mkldnn_reuse.h" #include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using framework::Tensor; using framework::Tensor;
template <typename T> template <typename T>
class FillConstantMKLDNNHandler class FillConstantMKLDNNHandler
: public platform::MKLDNNHandlerNoCachingT<T, dnnl::binary> { : public platform::MKLDNNHandlerNoCachingT<T, dnnl::binary> {
public: public:
FillConstantMKLDNNHandler(Tensor* out, FillConstantMKLDNNHandler(Tensor* out,
dnnl::engine engine, dnnl::engine engine,
platform::Place cpu_place) platform::Place cpu_place)
: platform::MKLDNNHandlerNoCachingT<T, dnnl::binary>(engine, cpu_place) { : platform::MKLDNNHandlerNoCachingT<T, dnnl::binary>(engine, cpu_place) {
const auto src0_md = const auto src0_md =
dnnl::memory::desc({out->numel(), sizeof(T)}, dnnl::memory::desc({out->numel(), sizeof(T)},
platform::MKLDNNGetDataType<uint8_t>(), platform::MKLDNNGetDataType<uint8_t>(),
dnnl::memory::format_tag::ab); dnnl::memory::format_tag::ab);
dnnl::primitive_attr attrs; dnnl::primitive_attr attrs;
attrs.set_scales(DNNL_ARG_SRC_0, /* mask = */ 0, {0.0f}); attrs.set_scales(DNNL_ARG_SRC_0, /* mask = */ 0, {0.0f});
this->AcquireForwardPrimitiveDescriptor( this->AcquireForwardPrimitiveDescriptor(
attrs, dnnl::algorithm::binary_add, src0_md, src1_md, src0_md); attrs, dnnl::algorithm::binary_add, src0_md, src1_md, src0_md);
} }
static const dnnl::memory::desc src1_md; static const dnnl::memory::desc src1_md;
}; };
template <typename T> template <typename T>
const dnnl::memory::desc FillConstantMKLDNNHandler<T>::src1_md( const dnnl::memory::desc FillConstantMKLDNNHandler<T>::src1_md(
{1, sizeof(T)}, {1, sizeof(T)},
platform::MKLDNNGetDataType<uint8_t>(), platform::MKLDNNGetDataType<uint8_t>(),
dnnl::memory::format_tag::ab); dnnl::memory::format_tag::ab);
template <typename T> template <typename T>
class FillConstantMKLDNNKernel : public framework::OpKernel<T> { class FillConstantMKLDNNKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
this->RunKernel(ctx); this->RunKernel(ctx);
} }
void RunKernel(const framework::ExecutionContext& ctx) const { void RunKernel(const framework::ExecutionContext& ctx) const {
const auto& dev_ctx = const auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>(); ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& dnnl_engine = dev_ctx.GetEngine(); const auto& dnnl_engine = dev_ctx.GetEngine();
auto* out = ctx.Output<Tensor>("Out"); auto* out = ctx.Output<Tensor>("Out");
T fill_value = CalculateFillValue(ctx); T fill_value = CalculateFillValue(ctx);
auto shape = GetShape(ctx); auto shape = GetShape(ctx);
out->Resize(shape); out->Resize(shape);
FillConstantMKLDNNHandler<T> handler(out, dnnl_engine, ctx.GetPlace()); FillConstantMKLDNNHandler<T> handler(out, dnnl_engine, ctx.GetPlace());
dnnl::memory constant_value_memory = dnnl::memory constant_value_memory =
dnnl::memory(FillConstantMKLDNNHandler<T>::src1_md, dnnl::memory(FillConstantMKLDNNHandler<T>::src1_md,
dnnl_engine, dnnl_engine,
reinterpret_cast<uint8_t*>(&fill_value)); reinterpret_cast<uint8_t*>(&fill_value));
auto src0_memory_p = handler.AcquireDstMemory(out); auto src0_memory_p = handler.AcquireDstMemory(out);
auto fill_constant_p = handler.AcquireForwardPrimitive(); auto fill_constant_p = handler.AcquireForwardPrimitive();
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
fill_constant_p->execute(astream, fill_constant_p->execute(astream,
{{DNNL_ARG_SRC_0, *src0_memory_p}, {{DNNL_ARG_SRC_0, *src0_memory_p},
{DNNL_ARG_SRC_1, constant_value_memory}, {DNNL_ARG_SRC_1, constant_value_memory},
{DNNL_ARG_DST, *src0_memory_p}}); {DNNL_ARG_DST, *src0_memory_p}});
astream.wait(); astream.wait();
// src0_memory_p's md was just to allow the usage of a binary // src0_memory_p's md was just to allow the usage of a binary
// primitive as a memset, and now we need to create a real one // primitive as a memset, and now we need to create a real one
out->set_mem_desc({phi::vectorize(shape), out->set_mem_desc({phi::vectorize(shape),
platform::MKLDNNGetDataType<T>(), platform::MKLDNNGetDataType<T>(),
platform::GetPlainMKLDNNFormat(shape.size())}); platform::GetPlainMKLDNNFormat(shape.size())});
} }
T CalculateFillValue(const framework::ExecutionContext& ctx) const { T CalculateFillValue(const framework::ExecutionContext& ctx) const {
const auto str_value = ctx.Attr<std::string>("str_value"); const auto str_value = ctx.Attr<std::string>("str_value");
const auto float_value = ctx.Attr<float>("value"); const auto float_value = ctx.Attr<float>("value");
T value; T value;
if (str_value.empty()) { if (str_value.empty()) {
value = static_cast<T>(float_value); value = static_cast<T>(float_value);
} else { } else {
// handle NaN/Inf first, which cannot be read from stream // handle NaN/Inf first, which cannot be read from stream
if (str_value == "inf") { if (str_value == "inf") {
value = static_cast<T>(std::numeric_limits<float>::infinity()); value = static_cast<T>(std::numeric_limits<float>::infinity());
} else if (str_value == "-inf") { } else if (str_value == "-inf") {
value = static_cast<T>(-std::numeric_limits<float>::infinity()); value = static_cast<T>(-std::numeric_limits<float>::infinity());
} else if (str_value == "nan") { } else if (str_value == "nan") {
value = static_cast<T>(std::numeric_limits<float>::quiet_NaN()); value = static_cast<T>(std::numeric_limits<float>::quiet_NaN());
} else { } else {
std::stringstream convert_stream(str_value); std::stringstream convert_stream(str_value);
double tmp_value; double tmp_value;
convert_stream >> tmp_value; convert_stream >> tmp_value;
value = static_cast<T>(tmp_value); value = static_cast<T>(tmp_value);
} }
} }
if (ctx.HasInput("ValueTensor")) { if (ctx.HasInput("ValueTensor")) {
const auto* value_tensor = ctx.Input<Tensor>("ValueTensor"); const auto* value_tensor = ctx.Input<Tensor>("ValueTensor");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
value_tensor->numel(), value_tensor->numel(),
1, 1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"When use Tensor as value to set Tensor value in fill_constant, " "When use Tensor as value to set Tensor value in fill_constant, "
"value input(ValueTensor) size must be 1, but got %d", "value input(ValueTensor) size must be 1, but got %d",
value_tensor->numel())); value_tensor->numel()));
value = value_tensor->data<T>()[0]; value = value_tensor->data<T>()[0];
} }
return value; return value;
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_KERNEL(fill_constant, REGISTER_OP_KERNEL(fill_constant,
MKLDNN, MKLDNN,
paddle::platform::CPUPlace, paddle::platform::CPUPlace,
ops::FillConstantMKLDNNKernel<float>); ops::FillConstantMKLDNNKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/unbind_op.h" #include "paddle/fluid/operators/unbind_op.h"
#include <string> #include <string>
#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h" #include "paddle/phi/infermeta/unary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using framework::Tensor; using framework::Tensor;
class UnbindOp : public framework::OperatorWithKernel { class UnbindOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), ctx->HasInput("X"),
true, true,
platform::errors::NotFound("Input(X) of UnbindOp is not found.")); platform::errors::NotFound("Input(X) of UnbindOp is not found."));
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
ctx->Outputs("Out").size(), ctx->Outputs("Out").size(),
1UL, 1UL,
platform::errors::NotFound("Outputs(Out) of UnbindOp is not found.")); platform::errors::NotFound("Outputs(Out) of UnbindOp is not found."));
auto in_dims = ctx->GetInputDim("X"); auto in_dims = ctx->GetInputDim("X");
auto outs_names = ctx->Outputs("Out"); auto outs_names = ctx->Outputs("Out");
int axis = ctx->Attrs().Get<int>("axis"); int axis = ctx->Attrs().Get<int>("axis");
const size_t outs_number = outs_names.size(); const size_t outs_number = outs_names.size();
auto out_dims = UnbindOutsDims(in_dims, axis); auto out_dims = UnbindOutsDims(in_dims, axis);
std::vector<framework::DDim> outs_dims(outs_number, out_dims); std::vector<framework::DDim> outs_dims(outs_number, out_dims);
ctx->SetOutputsDim("Out", outs_dims); ctx->SetOutputsDim("Out", outs_dims);
for (size_t i = 0; i < outs_number; ++i) { for (size_t i = 0; i < outs_number; ++i) {
ctx->ShareLoD("X", "Out", 0, i); ctx->ShareLoD("X", "Out", 0, i);
} }
} }
}; };
class UnbindOpMaker : public framework::OpProtoAndCheckerMaker { class UnbindOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("X", "(Tensor) Input tensor of the split operator."); AddInput("X", "(Tensor) Input tensor of the split operator.");
AddOutput("Out", "(Tensor) Output tensors of the unbind operator.") AddOutput("Out", "(Tensor) Output tensors of the unbind operator.")
.AsDuplicable(); .AsDuplicable();
AddComment(R"DOC( AddComment(R"DOC(
Unbind operator Unbind operator
Remove a tensor dimension. Remove a tensor dimension.
Example: Example:
Input = [[1,2], Input = [[1,2],
[3,4], [3,4],
[5,6]] [5,6]]
axis = 0 axis = 0
Output[0] = [1,2] Output[0] = [1,2]
Output[1] = [3,4] Output[1] = [3,4]
Output[2] = [5,6] Output[2] = [5,6]
)DOC"); )DOC");
AddAttr<int>("axis", AddAttr<int>("axis",
"(int, default 0) " "(int, default 0) "
"dimension to remove.") "dimension to remove.")
.SetDefault(0); .SetDefault(0);
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(unbind, REGISTER_OPERATOR(unbind,
ops::UnbindOp, ops::UnbindOp,
ops::UnbindOpMaker, ops::UnbindOpMaker,
ops::UnbindGradMaker<paddle::framework::OpDesc>, ops::UnbindGradMaker<paddle::framework::OpDesc>,
ops::UnbindGradMaker<paddle::imperative::OpBase>); ops::UnbindGradMaker<paddle::imperative::OpBase>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <chrono> // NOLINT #include <chrono> // NOLINT
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/strided_memcpy.h" #include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/fluid/operators/utils.h" #include "paddle/fluid/operators/utils.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
static inline framework::DDim UnbindOutsDims(const framework::DDim in_dims, static inline framework::DDim UnbindOutsDims(const framework::DDim in_dims,
int axis) { int axis) {
std::vector<int> out_dims; std::vector<int> out_dims;
axis = axis < 0 ? in_dims.size() + axis : axis; axis = axis < 0 ? in_dims.size() + axis : axis;
for (int i = 0; i < in_dims.size(); i++) { for (int i = 0; i < in_dims.size(); i++) {
if (i != axis) out_dims.push_back(in_dims[i]); if (i != axis) out_dims.push_back(in_dims[i]);
} }
return phi::make_ddim(out_dims); return phi::make_ddim(out_dims);
} }
template <typename T> template <typename T>
class UnbindGradMaker : public framework::SingleGradOpMaker<T> { class UnbindGradMaker : public framework::SingleGradOpMaker<T> {
public: public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
void Apply(GradOpPtr<T> op) const override { void Apply(GradOpPtr<T> op) const override {
op->SetType("stack"); op->SetType("stack");
op->SetInput("X", this->OutputGrad("Out")); op->SetInput("X", this->OutputGrad("Out"));
op->SetOutput("Y", this->InputGrad("X")); op->SetOutput("Y", this->InputGrad("X"));
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册