adapter npu (#31926)

Co-authored-by: Nbaiyangfan <baiyangfan@baidu.com>
Parent: ac89174e
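Summary of the diff below: the SectionWorker / PipelineTrainer declarations, registrations, and implementations now build when either PADDLE_WITH_NCCL or WITH_ASCEND_CL is defined, and PipelineTrainer::Initialize selects platform::NPUPlace instead of platform::CUDAPlace on Ascend builds. It also registers int kernels for the cast, expand, lookup_table_v2 (and grad), and slice (and grad) NPU ops; reads a new "optimize_offload" sharding config and comments out the _wait() call in ShardingOptimizer; drops the CUDAPlace assertion in the Section device worker; and makes PipelineOptimizer choose core.CUDAPlace / FLAGS_selected_gpus or core.NPUPlace / FLAGS_selected_npus depending on whether Paddle was compiled with CUDA or NPU.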
@@ -634,7 +634,7 @@ class PSGPUWorker : public HogwildWorker {
 };
 #endif
-#if defined(PADDLE_WITH_NCCL)
+#if (defined PADDLE_WITH_NCCL) || (defined WITH_ASCEND_CL)
 class SectionWorker : public DeviceWorker {
  public:
   SectionWorker() {}
......
@@ -76,7 +76,7 @@ REGISTER_DEVICE_WORKER_CLASS(HeterBoxWorker);
 REGISTER_DEVICE_WORKER_CLASS(PSGPUWorker);
 #endif
-#if defined(PADDLE_WITH_NCCL)
+#if (defined PADDLE_WITH_NCCL) || (defined WITH_ASCEND_CL)
 REGISTER_DEVICE_WORKER_CLASS(SectionWorker);
 #endif
 }  // namespace framework
......
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#if defined(PADDLE_WITH_NCCL)
+#if (defined PADDLE_WITH_NCCL) || (defined WITH_ASCEND_CL)
 #include <map>
 #include "paddle/fluid/framework/data_feed_factory.h"
 #include "paddle/fluid/framework/device_worker_factory.h"
@@ -35,7 +35,11 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
   ParseDumpConfig(trainer_desc);
   const auto& section_config = section_params.section_config();
   int place_id = section_config.place_id();
+#if (defined PADDLE_WITH_NCCL)
   place_ = platform::CUDAPlace(place_id);
+#elif (defined WITH_ASCEND_CL)
+  place_ = platform::NPUPlace(place_id);
+#endif
   worker_ = DeviceWorkerFactory::CreateDeviceWorker(
       trainer_desc.device_worker_name());
   auto this_worker =
......
@@ -9,7 +9,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#if defined(PADDLE_WITH_NCCL)
+#if (defined PADDLE_WITH_NCCL) || (defined WITH_ASCEND_CL)
 #include <float.h>
 #include "paddle/fluid/framework/device_worker.h"
 #include "paddle/fluid/framework/executor_gc_helper.h"
......
@@ -320,7 +320,7 @@ class PSGPUTrainer : public TrainerBase {
 };
 #endif
-#if defined(PADDLE_WITH_NCCL)
+#if (defined PADDLE_WITH_NCCL) || (defined WITH_ASCEND_CL)
 class PipelineTrainer : public TrainerBase {
  public:
   PipelineTrainer() {}
......
@@ -83,6 +83,7 @@ REGISTER_OP_NPU_KERNEL(
     ops::CastNPUKernel<paddle::platform::NPUDeviceContext, int16_t>,
     ops::CastNPUKernel<paddle::platform::NPUDeviceContext, int32_t>,
     ops::CastNPUKernel<paddle::platform::NPUDeviceContext, int64_t>,
+    ops::CastNPUKernel<paddle::platform::NPUDeviceContext, int>,
     ops::CastNPUKernel<paddle::platform::NPUDeviceContext, bool>,
     ops::CastNPUKernel<paddle::platform::NPUDeviceContext, double>,
     ops::CastNPUKernel<paddle::platform::NPUDeviceContext, float>,
......
@@ -76,6 +76,7 @@ class ExpandNPUKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 REGISTER_OP_NPU_KERNEL(
     expand, ops::ExpandNPUKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::ExpandNPUKernel<paddle::platform::NPUDeviceContext, int>,
     ops::ExpandNPUKernel<paddle::platform::NPUDeviceContext,
                          paddle::platform::float16>);
......
@@ -82,9 +82,11 @@ namespace ops = paddle::operators;
 REGISTER_OP_NPU_KERNEL(
     lookup_table_v2,
     ops::LookupTableV2NPUKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::LookupTableV2NPUKernel<paddle::platform::NPUDeviceContext, int>,
     ops::LookupTableV2NPUKernel<paddle::platform::NPUDeviceContext,
                                 paddle::platform::float16>);

 REGISTER_OP_NPU_KERNEL(
     lookup_table_v2_grad, ops::LookupTableV2GradNPUKernel<float>,
+    ops::LookupTableV2GradNPUKernel<int>,
     ops::LookupTableV2GradNPUKernel<paddle::platform::float16>);
@@ -124,11 +124,13 @@ namespace ops = paddle::operators;
 REGISTER_OP_NPU_KERNEL(
     slice, ops::SliceNPUKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::SliceNPUKernel<paddle::platform::NPUDeviceContext, int>,
     ops::SliceNPUKernel<paddle::platform::NPUDeviceContext,
                         paddle::platform::float16>);

 REGISTER_OP_NPU_KERNEL(
     slice_grad,
     ops::SliceGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::SliceGradNPUKernel<paddle::platform::NPUDeviceContext, int>,
     ops::SliceGradNPUKernel<paddle::platform::NPUDeviceContext,
                             paddle::platform::float16>);
@@ -103,6 +103,8 @@ class ShardingOptimizer(MetaOptimizerBase):
         self.pp_bz = self.user_defined_strategy.sharding_configs["pp_bz"]
         self.pp_allreduce_in_optimize = self.user_defined_strategy.sharding_configs[
             "pp_allreduce_in_optimize"]
+        self.optimize_offload = self.user_defined_strategy.sharding_configs[
+            "optimize_offload"]

         if self.inner_opt is None:
             raise ValueError(
@@ -238,7 +240,7 @@ class ShardingOptimizer(MetaOptimizerBase):
         #check_allreduce_sum(main_block, self._shard, self.sharding_ring_id,
         #                    self.dp_ring_id)
         #check_allreduce_sum(main_block, self._shard, self.dp_ring_id)
-        self._wait()
+        # self._wait()
         return optimize_ops, params_grads

     def _set_up(self, params_grads):
......
@@ -424,7 +424,7 @@ class Section(DeviceWorker):
         # cfg.program_desc.CopyFrom(program.program._get_desc())
         place = pipeline_opt["place"]
         place_id = pipeline_opt["place_id"]
-        assert isinstance(place, core.CUDAPlace)
+        # assert isinstance(place, core.CUDAPlace)
         cfg.place = cfg.CUDAPlace
         cfg.place_id = place_id
......
@@ -5272,7 +5272,10 @@ class PipelineOptimizer(object):
         place_list = []
         for dev in device_list:
             dev_index = int(dev.split(":")[1])
-            place_list.append(core.CUDAPlace(dev_index % 8))
+            if core.is_compiled_with_cuda():
+                place_list.append(core.CUDAPlace(dev_index % 1))
+            elif core.is_compiled_with_npu():
+                place_list.append(core.NPUPlace(dev_index % 1))

         # Step6: Split startup program
         new_startup_program = self._split_startup_program(startup_program,
@@ -5295,7 +5298,10 @@ class PipelineOptimizer(object):
         self._accumulate_gradients(real_block)
         real_block._sync_with_cpp()

-        place_id = int(os.getenv("FLAGS_selected_gpus", "0"))
+        if core.is_compiled_with_cuda():
+            place_id = int(os.getenv("FLAGS_selected_gpus", "0"))
+        elif core.is_compiled_with_npu():
+            place_id = int(os.getenv("FLAGS_selected_npus", "0"))
         main_program._pipeline_opt = {
             "trainer": "PipelineTrainer",
             "device_worker": "Section",
......
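For context only: a minimal, hedged sketch of how the sharding + pipeline path touched by this diff is typically driven from Python. The sharding_configs keys are taken from the hunks above ("pp_bz", "pp_allreduce_in_optimize", "optimize_offload"); the surrounding fleet and optimizer calls are ordinary Paddle 2.x static-graph usage, are not part of this change, and may need adjustment for a specific Paddle build.

import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=True)

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "pp_bz": 1,                         # pipeline micro-batch size read by ShardingOptimizer
    "pp_allreduce_in_optimize": False,
    "optimize_offload": False,          # new config key introduced in this diff
}

sgd = paddle.optimizer.SGD(learning_rate=0.01)
opt = fleet.distributed_optimizer(sgd, strategy)
# opt.minimize(loss) builds the pipeline program; on a WITH_ASCEND_CL build the
# PipelineOptimizer changes above select core.NPUPlace and FLAGS_selected_npus
# instead of core.CUDAPlace and FLAGS_selected_gpus.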