Adapt to NPU (#31926)

Co-authored-by: Nbaiyangfan <baiyangfan@baidu.com>
Parent ac89174e
@@ -634,7 +634,7 @@ class PSGPUWorker : public HogwildWorker {
 };
 #endif
-#if defined(PADDLE_WITH_NCCL)
+#if (defined PADDLE_WITH_NCCL) || (defined WITH_ASCEND_CL)
 class SectionWorker : public DeviceWorker {
  public:
  SectionWorker() {}
......
@@ -76,7 +76,7 @@ REGISTER_DEVICE_WORKER_CLASS(HeterBoxWorker);
 REGISTER_DEVICE_WORKER_CLASS(PSGPUWorker);
 #endif
-#if defined(PADDLE_WITH_NCCL)
+#if (defined PADDLE_WITH_NCCL) || (defined WITH_ASCEND_CL)
 REGISTER_DEVICE_WORKER_CLASS(SectionWorker);
 #endif
 }  // namespace framework
......
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#if defined(PADDLE_WITH_NCCL)
+#if (defined PADDLE_WITH_NCCL) || (defined WITH_ASCEND_CL)
 #include <map>
 #include "paddle/fluid/framework/data_feed_factory.h"
 #include "paddle/fluid/framework/device_worker_factory.h"
@@ -35,7 +35,11 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
   ParseDumpConfig(trainer_desc);
   const auto& section_config = section_params.section_config();
   int place_id = section_config.place_id();
+#if (defined PADDLE_WITH_NCCL)
   place_ = platform::CUDAPlace(place_id);
+#elif (defined WITH_ASCEND_CL)
+  place_ = platform::NPUPlace(place_id);
+#endif
   worker_ = DeviceWorkerFactory::CreateDeviceWorker(
       trainer_desc.device_worker_name());
   auto this_worker =
......
@@ -9,7 +9,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#if defined(PADDLE_WITH_NCCL)
+#if (defined PADDLE_WITH_NCCL) || (defined WITH_ASCEND_CL)
 #include <float.h>
 #include "paddle/fluid/framework/device_worker.h"
 #include "paddle/fluid/framework/executor_gc_helper.h"
......
@@ -320,7 +320,7 @@ class PSGPUTrainer : public TrainerBase {
 };
 #endif
-#if defined(PADDLE_WITH_NCCL)
+#if (defined PADDLE_WITH_NCCL) || (defined WITH_ASCEND_CL)
 class PipelineTrainer : public TrainerBase {
  public:
  PipelineTrainer() {}
......
@@ -83,6 +83,7 @@ REGISTER_OP_NPU_KERNEL(
     ops::CastNPUKernel<paddle::platform::NPUDeviceContext, int16_t>,
     ops::CastNPUKernel<paddle::platform::NPUDeviceContext, int32_t>,
     ops::CastNPUKernel<paddle::platform::NPUDeviceContext, int64_t>,
+    ops::CastNPUKernel<paddle::platform::NPUDeviceContext, int>,
     ops::CastNPUKernel<paddle::platform::NPUDeviceContext, bool>,
     ops::CastNPUKernel<paddle::platform::NPUDeviceContext, double>,
     ops::CastNPUKernel<paddle::platform::NPUDeviceContext, float>,
......
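With the extra `int` registration above, integer tensors can go through `cast` without leaving the device. A minimal usage sketch, assuming the Paddle 2.x dynamic-graph Python API on a WITH_ASCEND_CL build (the snippet falls back to CPU elsewhere; device name `npu:0` is illustrative):

```python
import paddle

# Use the Ascend device only when this build actually has NPU support.
paddle.set_device('npu:0' if paddle.is_compiled_with_npu() else 'cpu')

x = paddle.to_tensor([1, 2, 3], dtype='int32')
y = paddle.cast(x, 'float32')  # integer input dispatches to CastNPUKernel
print(y.numpy())               # [1. 2. 3.]
```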
@@ -76,6 +76,7 @@ class ExpandNPUKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 REGISTER_OP_NPU_KERNEL(
     expand, ops::ExpandNPUKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::ExpandNPUKernel<paddle::platform::NPUDeviceContext, int>,
     ops::ExpandNPUKernel<paddle::platform::NPUDeviceContext,
                          paddle::platform::float16>);
......
@@ -82,9 +82,11 @@ namespace ops = paddle::operators;
 REGISTER_OP_NPU_KERNEL(
     lookup_table_v2,
     ops::LookupTableV2NPUKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::LookupTableV2NPUKernel<paddle::platform::NPUDeviceContext, int>,
     ops::LookupTableV2NPUKernel<paddle::platform::NPUDeviceContext,
                                 paddle::platform::float16>);
 
 REGISTER_OP_NPU_KERNEL(
     lookup_table_v2_grad, ops::LookupTableV2GradNPUKernel<float>,
+    ops::LookupTableV2GradNPUKernel<int>,
     ops::LookupTableV2GradNPUKernel<paddle::platform::float16>);
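`lookup_table_v2` is the op behind `paddle.nn.Embedding`, so these registrations are what an embedding lookup dispatches to on an Ascend device. A minimal sketch under the same assumptions as the previous snippet (shapes and ids are illustrative):

```python
import paddle

paddle.set_device('npu:0' if paddle.is_compiled_with_npu() else 'cpu')

emb = paddle.nn.Embedding(100, 16)                 # lowers to lookup_table_v2
ids = paddle.to_tensor([3, 7, 42], dtype='int64')
out = emb(ids)                                     # forward lookup
out.sum().backward()                               # runs lookup_table_v2_grad
```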
@@ -124,11 +124,13 @@ namespace ops = paddle::operators;
 REGISTER_OP_NPU_KERNEL(
     slice, ops::SliceNPUKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::SliceNPUKernel<paddle::platform::NPUDeviceContext, int>,
     ops::SliceNPUKernel<paddle::platform::NPUDeviceContext,
                         paddle::platform::float16>);
 
 REGISTER_OP_NPU_KERNEL(
     slice_grad,
     ops::SliceGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::SliceGradNPUKernel<paddle::platform::NPUDeviceContext, int>,
     ops::SliceGradNPUKernel<paddle::platform::NPUDeviceContext,
                             paddle::platform::float16>);
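With the `int` kernels registered, slicing integer tensors can stay on the NPU instead of failing for an unsupported dtype. A small sketch, same assumptions as the earlier snippets:

```python
import paddle

paddle.set_device('npu:0' if paddle.is_compiled_with_npu() else 'cpu')

x = paddle.arange(12, dtype='int32').reshape([3, 4])
# Rows 0:2, columns 1:3 of an int32 tensor; int input now dispatches
# to the SliceNPUKernel<..., int> added above.
y = paddle.slice(x, axes=[0, 1], starts=[0, 1], ends=[2, 3])
print(y.numpy())  # [[1 2]
                  #  [5 6]]
```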
@@ -103,6 +103,8 @@ class ShardingOptimizer(MetaOptimizerBase):
         self.pp_bz = self.user_defined_strategy.sharding_configs["pp_bz"]
         self.pp_allreduce_in_optimize = self.user_defined_strategy.sharding_configs[
             "pp_allreduce_in_optimize"]
+        self.optimize_offload = self.user_defined_strategy.sharding_configs[
+            "optimize_offload"]
 
         if self.inner_opt is None:
             raise ValueError(
@@ -238,7 +240,7 @@ class ShardingOptimizer(MetaOptimizerBase):
         #check_allreduce_sum(main_block, self._shard, self.sharding_ring_id,
         #                    self.dp_ring_id)
         #check_allreduce_sum(main_block, self._shard, self.dp_ring_id)
-        self._wait()
+        # self._wait()
 
         return optimize_ops, params_grads
 
     def _set_up(self, params_grads):
......
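The new `optimize_offload` switch is read from `sharding_configs` alongside the existing `pp_bz` and `pp_allreduce_in_optimize` keys. A hedged sketch of how a user would pass it through fleet's `DistributedStrategy`; the values are illustrative, and the keys must exist in this branch's distributed-strategy proto for the assignment to be accepted:

```python
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "pp_bz": 1,                        # keys read by ShardingOptimizer above
    "pp_allreduce_in_optimize": False,
    "optimize_offload": True,          # the switch added in this commit
}
# optimizer = fleet.distributed_optimizer(inner_optimizer, strategy)
```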
@@ -424,7 +424,7 @@ class Section(DeviceWorker):
         # cfg.program_desc.CopyFrom(program.program._get_desc())
         place = pipeline_opt["place"]
         place_id = pipeline_opt["place_id"]
-        assert isinstance(place, core.CUDAPlace)
+        # assert isinstance(place, core.CUDAPlace)
         cfg.place = cfg.CUDAPlace
         cfg.place_id = place_id
......
@@ -5272,7 +5272,10 @@ class PipelineOptimizer(object):
         place_list = []
         for dev in device_list:
             dev_index = int(dev.split(":")[1])
-            place_list.append(core.CUDAPlace(dev_index % 8))
+            if core.is_compiled_with_cuda():
+                place_list.append(core.CUDAPlace(dev_index % 1))
+            elif core.is_compiled_with_npu():
+                place_list.append(core.NPUPlace(dev_index % 1))
 
         # Step6: Split startup program
         new_startup_program = self._split_startup_program(startup_program,
@@ -5295,7 +5298,10 @@ class PipelineOptimizer(object):
             self._accumulate_gradients(real_block)
         real_block._sync_with_cpp()
 
-        place_id = int(os.getenv("FLAGS_selected_gpus", "0"))
+        if core.is_compiled_with_cuda():
+            place_id = int(os.getenv("FLAGS_selected_gpus", "0"))
+        elif core.is_compiled_with_npu():
+            place_id = int(os.getenv("FLAGS_selected_npus", "0"))
         main_program._pipeline_opt = {
             "trainer": "PipelineTrainer",
             "device_worker": "Section",
......
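Taken together, the pipeline place is now derived from whichever environment variable matches the compile target. The same dispatch pulled out as a standalone sketch; note that `core.NPUPlace` only exists in WITH_ASCEND_CL builds, so the compile-flag checks must come before the place constructors (the CPU fallback is not part of the diff):

```python
import os
from paddle.fluid import core

# Pick the rank-local device index from the env var that matches the build.
if core.is_compiled_with_cuda():
    place = core.CUDAPlace(int(os.getenv("FLAGS_selected_gpus", "0")))
elif core.is_compiled_with_npu():
    place = core.NPUPlace(int(os.getenv("FLAGS_selected_npus", "0")))
else:
    place = core.CPUPlace()  # fallback for CPU-only builds
```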