提交 a4586d17 编写于 作者: P phlrain

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_some_yaml_config

此差异已折叠。
......@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include <thrust/host_vector.h>
#include "heter_comm.h"
#include "paddle/fluid/distributed/ps/table/common_graph_table.h"
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
......@@ -40,11 +41,13 @@ class GpuPsGraphTable : public HeterComm<int64_t, int, int> {
int sample_size, int len);
NodeQueryResult *query_node_list(int gpu_id, int start, int query_size);
void clear_graph_info();
void move_neighbor_sample_result_to_source_gpu(int gpu_id, int gpu_num,
int sample_size, int *h_left,
int *h_right,
int64_t *src_sample_res,
int *actual_sample_size);
void move_neighbor_sample_result_to_source_gpu(
int gpu_id, int gpu_num, int *h_left, int *h_right,
int64_t *src_sample_res, thrust::host_vector<int> &total_sample_size);
void move_neighbor_sample_size_to_source_gpu(int gpu_id, int gpu_num,
int *h_left, int *h_right,
int *actual_sample_size,
int *total_sample_size);
int init_cpu_table(const paddle::distributed::GraphParameter &graph);
int load(const std::string &path, const std::string &param);
virtual int32_t end_graph_sampling() {
......
......@@ -13,10 +13,23 @@
// limitations under the License.
#pragma once
#include <cuda_runtime.h>
#include <curand_kernel.h>
#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/reduce.h>
#include <thrust/scan.h>
#include <thrust/transform.h>
#ifdef PADDLE_WITH_HETERPS
//#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h"
namespace paddle {
namespace framework {
constexpr int WARP_SIZE = 32;
/*
comment 0
this kernel just serves as an example of how to sample nodes' neighbors.
......@@ -29,20 +42,79 @@ sample_size;
*/
__global__ void neighbor_sample_example(GpuPsCommGraph graph, int* index,
int* actual_size,
int64_t* sample_result, int sample_size,
int len) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
struct MaxFunctor {
int sample_size;
HOSTDEVICE explicit inline MaxFunctor(int sample_size) {
this->sample_size = sample_size;
}
HOSTDEVICE inline int operator()(int x) const {
if (x > sample_size) {
return sample_size;
}
return x;
}
};
struct DegreeFunctor {
GpuPsCommGraph graph;
HOSTDEVICE explicit inline DegreeFunctor(GpuPsCommGraph graph) {
this->graph = graph;
}
HOSTDEVICE inline int operator()(int i) const {
return graph.node_list[i].neighbor_size;
}
};
template <int BLOCK_WARPS, int TILE_SIZE>
__global__ void neighbor_sample(const uint64_t rand_seed, GpuPsCommGraph graph,
int sample_size, int* index, int len,
int64_t* sample_result, int* output_idx,
int* output_offset) {
assert(blockDim.x == WARP_SIZE);
assert(blockDim.y == BLOCK_WARPS);
int i = blockIdx.x * TILE_SIZE + threadIdx.y;
const int last_idx = min(static_cast<int>(blockIdx.x + 1) * TILE_SIZE, len);
curandState rng;
curand_init(rand_seed * gridDim.x + blockIdx.x,
threadIdx.y * WARP_SIZE + threadIdx.x, 0, &rng);
while (i < last_idx) {
auto node_index = index[i];
actual_size[i] = graph.node_list[node_index].neighbor_size < sample_size
? graph.node_list[node_index].neighbor_size
: sample_size;
int offset = graph.node_list[node_index].neighbor_offset;
for (int j = 0; j < actual_size[i]; j++) {
sample_result[sample_size * i + j] = graph.neighbor_list[offset + j];
int degree = graph.node_list[node_index].neighbor_size;
const int offset = graph.node_list[node_index].neighbor_offset;
int output_start = output_offset[i];
if (degree <= sample_size) {
// Just copy
for (int j = threadIdx.x; j < degree; j += WARP_SIZE) {
sample_result[output_start + j] = graph.neighbor_list[offset + j];
}
} else {
for (int j = threadIdx.x; j < degree; j += WARP_SIZE) {
output_idx[output_start + j] = j;
}
__syncwarp();
for (int j = sample_size + threadIdx.x; j < degree; j += WARP_SIZE) {
const int num = curand(&rng) % (j + 1);
if (num < sample_size) {
atomicMax(
reinterpret_cast<unsigned int*>(output_idx + output_start + num),
static_cast<unsigned int>(j));
}
}
__syncwarp();
for (int j = threadIdx.x; j < sample_size; j += WARP_SIZE) {
const int perm_idx = output_idx[output_start + j] + offset;
sample_result[output_start + j] = graph.neighbor_list[perm_idx];
}
}
i += BLOCK_WARPS;
}
}
......@@ -79,7 +151,7 @@ int GpuPsGraphTable::load(const std::string& path, const std::string& param) {
gpu i triggers a neighbor_sample task,
when this task is done,
this function is called to move the sample result on other gpu back
to gup i and aggragate the result.
to gpu i and aggragate the result.
the sample_result is saved on src_sample_res and the actual sample size for
each node is saved on actual_sample_size.
the number of actual sample_result for
......@@ -96,10 +168,50 @@ int GpuPsGraphTable::load(const std::string& path, const std::string& param) {
that's what fill_dvals does.
*/
void GpuPsGraphTable::move_neighbor_sample_size_to_source_gpu(
int gpu_id, int gpu_num, int* h_left, int* h_right, int* actual_sample_size,
int* total_sample_size) {
// This function copyed actual_sample_size to source_gpu,
// and calculate total_sample_size of each gpu sample number.
for (int i = 0; i < gpu_num; i++) {
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
auto shard_len = h_right[i] - h_left[i] + 1;
auto& node = path_[gpu_id][i].nodes_.front();
cudaMemcpyAsync(reinterpret_cast<char*>(actual_sample_size + h_left[i]),
node.val_storage + sizeof(int) * shard_len,
sizeof(int) * shard_len, cudaMemcpyDefault,
node.out_stream);
}
for (int i = 0; i < gpu_num; ++i) {
if (h_left[i] == -1 || h_right[i] == -1) {
total_sample_size[i] = 0;
continue;
}
auto& node = path_[gpu_id][i].nodes_.front();
cudaStreamSynchronize(node.out_stream);
auto shard_len = h_right[i] - h_left[i] + 1;
thrust::device_vector<int> t_actual_sample_size(shard_len);
thrust::copy(actual_sample_size + h_left[i],
actual_sample_size + h_left[i] + shard_len,
t_actual_sample_size.begin());
total_sample_size[i] = thrust::reduce(t_actual_sample_size.begin(),
t_actual_sample_size.end());
}
}
void GpuPsGraphTable::move_neighbor_sample_result_to_source_gpu(
int gpu_id, int gpu_num, int sample_size, int* h_left, int* h_right,
int64_t* src_sample_res, int* actual_sample_size) {
int gpu_id, int gpu_num, int* h_left, int* h_right, int64_t* src_sample_res,
thrust::host_vector<int>& total_sample_size) {
/*
if total_sample_size is [4, 5, 1, 6],
then cumsum_total_sample_size is [0, 4, 9, 10];
*/
thrust::host_vector<int> cumsum_total_sample_size(gpu_num, 0);
thrust::exclusive_scan(total_sample_size.begin(), total_sample_size.end(),
cumsum_total_sample_size.begin(), 0);
for (int i = 0; i < gpu_num; i++) {
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
......@@ -109,14 +221,10 @@ void GpuPsGraphTable::move_neighbor_sample_result_to_source_gpu(
// auto& node = path_[gpu_id][i].nodes_[cur_step];
auto& node = path_[gpu_id][i].nodes_.front();
cudaMemcpyAsync(
reinterpret_cast<char*>(src_sample_res + h_left[i] * sample_size),
reinterpret_cast<char*>(src_sample_res + cumsum_total_sample_size[i]),
node.val_storage + sizeof(int64_t) * shard_len,
node.val_bytes_len - sizeof(int64_t) * shard_len, cudaMemcpyDefault,
sizeof(int64_t) * total_sample_size[i], cudaMemcpyDefault,
node.out_stream);
cudaMemcpyAsync(reinterpret_cast<char*>(actual_sample_size + h_left[i]),
node.val_storage + sizeof(int) * shard_len,
sizeof(int) * shard_len, cudaMemcpyDefault,
node.out_stream);
}
for (int i = 0; i < gpu_num; ++i) {
if (h_left[i] == -1 || h_right[i] == -1) {
......@@ -131,17 +239,35 @@ void GpuPsGraphTable::move_neighbor_sample_result_to_source_gpu(
TODO:
how to optimize it to eliminate the for loop
*/
__global__ void fill_dvalues(int64_t* d_shard_vals, int64_t* d_vals,
int* d_shard_actual_sample_size,
int* d_actual_sample_size, int* idx,
int sample_size, int len) {
__global__ void fill_dvalues_actual_sample_size(int* d_shard_actual_sample_size,
int* d_actual_sample_size,
int* idx, int len) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
d_actual_sample_size[idx[i]] = d_shard_actual_sample_size[i];
// d_vals[idx[i]] = d_shard_vals[i];
for (int j = 0; j < sample_size; j++) {
d_vals[idx[i] * sample_size + j] = d_shard_vals[i * sample_size + j];
}
}
template <int BLOCK_WARPS, int TILE_SIZE>
__global__ void fill_dvalues_sample_result(int64_t* d_shard_vals,
int64_t* d_vals,
int* d_actual_sample_size, int* idx,
int* offset, int* d_offset,
int len) {
assert(blockDim.x == WARP_SIZE);
assert(blockDim.y == BLOCK_WARPS);
int i = blockIdx.x * TILE_SIZE + threadIdx.y;
const int last_idx = min(static_cast<int>(blockIdx.x + 1) * TILE_SIZE, len);
while (i < last_idx) {
const int sample_size = d_actual_sample_size[idx[i]];
for (int j = threadIdx.x; j < sample_size; j += WARP_SIZE) {
d_vals[offset[idx[i]] + j] = d_shard_vals[d_offset[i] + j];
}
#ifdef PADDLE_WITH_CUDA
__syncwarp();
#endif
i += BLOCK_WARPS;
}
}
......@@ -255,14 +381,12 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
h_left = [0,5],h_right = [4,8]
*/
NeighborSampleResult* result = new NeighborSampleResult(sample_size, len);
if (len == 0) {
return result;
}
cudaMalloc((void**)&result->val, len * sample_size * sizeof(int64_t));
cudaMalloc((void**)&result->actual_sample_size, len * sizeof(int));
int* actual_sample_size = result->actual_sample_size;
int64_t* val = result->val;
int total_gpu = resource_->total_gpu();
int dev_id = resource_->dev_id(gpu_id);
platform::CUDAPlace place = platform::CUDAPlace(dev_id);
......@@ -287,11 +411,6 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
auto d_shard_keys = memory::Alloc(place, len * sizeof(int64_t));
int64_t* d_shard_keys_ptr = reinterpret_cast<int64_t*>(d_shard_keys->ptr());
auto d_shard_vals = memory::Alloc(place, sample_size * len * sizeof(int64_t));
int64_t* d_shard_vals_ptr = reinterpret_cast<int64_t*>(d_shard_vals->ptr());
auto d_shard_actual_sample_size = memory::Alloc(place, len * sizeof(int));
int* d_shard_actual_sample_size_ptr =
reinterpret_cast<int*>(d_shard_actual_sample_size->ptr());
split_input_to_shard(key, d_idx_ptr, len, d_left_ptr, d_right_ptr, gpu_id);
......@@ -331,6 +450,7 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
of alloc_mem_i, actual_sample_size_of_x equals ((int
*)alloc_mem_i)[shard_len + x]
*/
create_storage(gpu_id, i, shard_len * sizeof(int64_t),
shard_len * (1 + sample_size) * sizeof(int64_t));
}
......@@ -351,6 +471,7 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
h_right[i] - h_left[i] + 1,
resource_->remote_stream(i, gpu_id));
}
for (int i = 0; i < total_gpu; ++i) {
if (h_left[i] == -1) {
continue;
......@@ -364,10 +485,42 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
int* res_array = reinterpret_cast<int*>(node.val_storage);
int* actual_size_array = res_array + shard_len;
int64_t* sample_array = (int64_t*)(res_array + shard_len * 2);
neighbor_sample_example<<<grid_size, block_size_, 0,
resource_->remote_stream(i, gpu_id)>>>(
graph, res_array, actual_size_array, sample_array, sample_size,
shard_len);
// 1. get actual_size_array.
// 2. get sum of actual_size.
// 3. get offset ptr
thrust::device_vector<int> t_res_array(shard_len);
thrust::copy(res_array, res_array + shard_len, t_res_array.begin());
thrust::device_vector<int> t_actual_size_array(shard_len);
thrust::transform(t_res_array.begin(), t_res_array.end(),
t_actual_size_array.begin(), DegreeFunctor(graph));
if (sample_size >= 0) {
thrust::transform(t_actual_size_array.begin(), t_actual_size_array.end(),
t_actual_size_array.begin(), MaxFunctor(sample_size));
}
thrust::copy(t_actual_size_array.begin(), t_actual_size_array.end(),
actual_size_array);
int total_sample_sum =
thrust::reduce(t_actual_size_array.begin(), t_actual_size_array.end());
thrust::device_vector<int> output_idx(total_sample_sum);
thrust::device_vector<int> output_offset(shard_len);
thrust::exclusive_scan(t_actual_size_array.begin(),
t_actual_size_array.end(), output_offset.begin(), 0);
constexpr int BLOCK_WARPS = 128 / WARP_SIZE;
constexpr int TILE_SIZE = BLOCK_WARPS * 16;
const dim3 block_(WARP_SIZE, BLOCK_WARPS);
const dim3 grid_((shard_len + TILE_SIZE - 1) / TILE_SIZE);
neighbor_sample<
BLOCK_WARPS,
TILE_SIZE><<<grid_, block_, 0, resource_->remote_stream(i, gpu_id)>>>(
0, graph, sample_size, res_array, shard_len, sample_array,
thrust::raw_pointer_cast(output_idx.data()),
thrust::raw_pointer_cast(output_offset.data()));
}
for (int i = 0; i < total_gpu; ++i) {
......@@ -378,13 +531,56 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
tables_[i]->rwlock_->UNLock();
}
// walk_to_src(num, total_gpu, h_left, h_right, d_shard_vals_ptr);
move_neighbor_sample_result_to_source_gpu(gpu_id, total_gpu, sample_size,
h_left, h_right, d_shard_vals_ptr,
d_shard_actual_sample_size_ptr);
fill_dvalues<<<grid_size, block_size_, 0, stream>>>(
d_shard_vals_ptr, val, d_shard_actual_sample_size_ptr, actual_sample_size,
d_idx_ptr, sample_size, len);
auto d_shard_actual_sample_size = memory::Alloc(place, len * sizeof(int));
int* d_shard_actual_sample_size_ptr =
reinterpret_cast<int*>(d_shard_actual_sample_size->ptr());
// Store total sample number of each gpu.
thrust::host_vector<int> d_shard_total_sample_size(total_gpu, 0);
move_neighbor_sample_size_to_source_gpu(
gpu_id, total_gpu, h_left, h_right, d_shard_actual_sample_size_ptr,
thrust::raw_pointer_cast(d_shard_total_sample_size.data()));
int allocate_sample_num = 0;
for (int i = 0; i < total_gpu; ++i) {
allocate_sample_num += d_shard_total_sample_size[i];
}
auto d_shard_vals =
memory::Alloc(place, allocate_sample_num * sizeof(int64_t));
int64_t* d_shard_vals_ptr = reinterpret_cast<int64_t*>(d_shard_vals->ptr());
move_neighbor_sample_result_to_source_gpu(gpu_id, total_gpu, h_left, h_right,
d_shard_vals_ptr,
d_shard_total_sample_size);
cudaMalloc((void**)&result->val, allocate_sample_num * sizeof(int64_t));
cudaMalloc((void**)&result->actual_sample_size, len * sizeof(int));
cudaMalloc((void**)&result->offset, len * sizeof(int));
int64_t* val = result->val;
int* actual_sample_size = result->actual_sample_size;
int* offset = result->offset;
fill_dvalues_actual_sample_size<<<grid_size, block_size_, 0, stream>>>(
d_shard_actual_sample_size_ptr, actual_sample_size, d_idx_ptr, len);
thrust::device_vector<int> t_actual_sample_size(len);
thrust::copy(actual_sample_size, actual_sample_size + len,
t_actual_sample_size.begin());
thrust::exclusive_scan(t_actual_sample_size.begin(),
t_actual_sample_size.end(), offset, 0);
int* d_offset;
cudaMalloc(&d_offset, len * sizeof(int));
thrust::copy(d_shard_actual_sample_size_ptr,
d_shard_actual_sample_size_ptr + len,
t_actual_sample_size.begin());
thrust::exclusive_scan(t_actual_sample_size.begin(),
t_actual_sample_size.end(), d_offset, 0);
constexpr int BLOCK_WARPS_ = 128 / WARP_SIZE;
constexpr int TILE_SIZE_ = BLOCK_WARPS_ * 16;
const dim3 block__(WARP_SIZE, BLOCK_WARPS_);
const dim3 grid__((len + TILE_SIZE_ - 1) / TILE_SIZE_);
fill_dvalues_sample_result<BLOCK_WARPS_,
TILE_SIZE_><<<grid__, block__, 0, stream>>>(
d_shard_vals_ptr, val, actual_sample_size, d_idx_ptr, offset, d_offset,
len);
cudaStreamSynchronize(stream);
for (int i = 0; i < total_gpu; ++i) {
int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1;
......@@ -393,6 +589,7 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
}
destroy_storage(gpu_id, i);
}
cudaFree(d_offset);
return result;
}
......
......@@ -94,19 +94,44 @@ TEST(TEST_FLEET, graph_comm) {
0 --index--->0
7 --index-->2
*/
int64_t cpu_key[3] = {7, 0, 6};
void *key;
cudaMalloc((void **)&key, 3 * sizeof(int64_t));
cudaMemcpy(key, cpu_key, 3 * sizeof(int64_t), cudaMemcpyHostToDevice);
auto neighbor_sample_res = g.graph_neighbor_sample(0, (int64_t *)key, 3, 3);
res = new int64_t[9];
cudaMemcpy(res, neighbor_sample_res->val, 72, cudaMemcpyDeviceToHost);
int64_t expected_sample_val[] = {28, 29, 30, 0, -1, -1, 21, 22, 23};
for (int i = 0; i < 9; i++) {
if (expected_sample_val[i] != -1) {
ASSERT_EQ(res[i], expected_sample_val[i]);
res = new int64_t[7];
cudaMemcpy(res, neighbor_sample_res->val, 56, cudaMemcpyDeviceToHost);
int *actual_sample_size = new int[3];
cudaMemcpy(actual_sample_size, neighbor_sample_res->actual_sample_size, 12,
cudaMemcpyDeviceToHost); // 3, 1, 3
int *cumsum_sample_size = new int[3];
cudaMemcpy(cumsum_sample_size, neighbor_sample_res->offset, 12,
cudaMemcpyDeviceToHost); // 0, 3, 4
std::vector<std::vector<int64_t>> neighbors_;
std::vector<int64_t> neighbors_7 = {28, 29, 30, 31, 32, 33, 34, 35};
std::vector<int64_t> neighbors_0 = {0};
std::vector<int64_t> neighbors_6 = {21, 22, 23, 24, 25, 26, 27};
neighbors_.push_back(neighbors_7);
neighbors_.push_back(neighbors_0);
neighbors_.push_back(neighbors_6);
for (int i = 0; i < 3; i++) {
for (int j = cumsum_sample_size[i];
j < cumsum_sample_size[i] + actual_sample_size[i]; j++) {
bool flag = false;
for (int k = 0; k < neighbors_[i].size(); k++) {
if (res[j] == neighbors_[i][k]) {
flag = true;
break;
}
}
ASSERT_EQ(flag, true);
}
}
delete[] res;
delete[] actual_sample_size;
delete[] cumsum_sample_size;
delete neighbor_sample_res;
}
......@@ -25,14 +25,14 @@ std::set<std::string> ignored_ops = {
"sum",
"clip",
"clip_by_norm",
"square",
"reduce_sum",
"sqrt",
"elementwise_max",
"elementwise_div",
"elementwise_mul",
"scale", // adamax
"assign", // adamw
"scale", // adamax
"assign", // adamw
"squared_l2_norm" // gradient_clip_norm
};
const bool startswith(const std::string& str, const std::string& pre) {
......@@ -62,6 +62,10 @@ void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const {
new_op.SetAttr("with_lr_sched", false);
std::set<std::string> set_ops{};
// save the weight decay tensor_name and weight_decay_value for Lamb
std::vector<std::string> weight_decay_vars{};
std::vector<float> weight_decay_values{};
// use map store <op_type, op_ptr> ?
for (auto* node : graph->Nodes()) {
if (!node->IsOp()) {
......@@ -75,6 +79,15 @@ void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const {
auto op_role = static_cast<OpRole>(op_role_);
if (op_role == OpRole::kOptimize) {
// save weight decay value from every lamb optimizer op
if (op_type == "lamb" && op->HasAttr("weight_decay")) {
auto weight_decay_value =
BOOST_GET_CONST(float, op->GetAttr("weight_decay"));
auto params = op->Output("ParamOut");
weight_decay_vars.push_back(params[0]);
weight_decay_values.push_back(weight_decay_value);
}
if (set_ops.count(op_type)) {
continue;
}
......@@ -270,7 +283,10 @@ void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const {
// seems with_lr_sched is always true
new_op.SetAttr("with_lr_sched", true);
// setup weight deacy
// setup weight decay for Lamb
new_op.SetAttr("weight_decay_vars", weight_decay_vars);
new_op.SetAttr("weight_decay_values", weight_decay_values);
// weight_decay/coeff is "scale" attr of scale_op
if (set_ops.count("scale") && set_ops.count("sum")) {
if (set_ops.count("sign")) {
......
......@@ -30,7 +30,8 @@ void TransferCastOpPass::ApplyImpl(ir::Graph* graph) const {
auto ipu_backend = platform::ipu::IpuBackend::GetInstance();
auto enable_fp16 = ipu_backend->GetIpuStrategy()->enable_fp16;
if (enable_fp16) {
auto transfer_cast_op = ipu_backend->GetIpuStrategy()->transfer_cast_op;
if (enable_fp16 && transfer_cast_op) {
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()->Type() == "popart_cast") {
if (BOOST_GET_CONST(std::string, node->Op()->GetAttr("to")) ==
......
......@@ -28,7 +28,7 @@
USE_OP_ITSELF(batch_norm);
USE_OP_DEVICE_KERNEL(batch_norm, MKLDNN);
USE_OP(conv2d_transpose);
USE_OP_ITSELF(conv2d_transpose);
USE_OP_DEVICE_KERNEL(conv2d_transpose, MKLDNN);
USE_OP_ITSELF(elementwise_add);
USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
......
......@@ -79,18 +79,6 @@ void TensorCopyImpl(const TENSOR& src, const platform::Place& dst_place,
if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
#ifdef PADDLE_WITH_IPU
else if (platform::is_ipu_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
} else if (platform::is_cpu_place(src_place) &&
platform::is_ipu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
} else if (platform::is_ipu_place(src_place) &&
platform::is_ipu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
else if (platform::is_custom_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) {
......@@ -390,6 +378,29 @@ void TensorCopyImpl(const TENSOR& src, const platform::Place& dst_place,
"Copying from %s to %s is not supported.", src_place, dst_place));
}
#endif
#ifdef PADDLE_WITH_IPU
else if (platform::is_ipu_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
else if (platform::is_cpu_place(src_place) && // NOLINT
platform::is_ipu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
else if (platform::is_ipu_place(src_place) && // NOLINT
platform::is_ipu_place(dst_place)) {
if (src_ptr == dst_ptr) {
VLOG(3) << "Skip copy the same data sync from " << src_place << " to "
<< dst_place;
return;
}
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
else { // NOLINT
PADDLE_THROW(platform::errors::Unimplemented(
"Copying from %s to %s is not supported.", src_place, dst_place));
}
#endif
}
template <typename TENSOR>
......@@ -447,27 +458,15 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
#ifdef PADDLE_WITH_IPU
else if (platform::is_ipu_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
} else if (platform::is_cpu_place(src_place) && // NOLINT
platform::is_ipu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
} else { // NOLINT
PADDLE_THROW(platform::errors::Unimplemented(
"Copy from %s to %s is not supported.", src_place, dst_place));
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
else if (platform::is_custom_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) { /* custom_device -> cpu*/
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, nullptr);
}
} // NOLINT
else if (platform::is_cpu_place(src_place) && // NOLINT
platform::is_custom_place(dst_place)) { /* cpu -> custom_device*/
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, nullptr);
}
} // NOLINT
else if (platform::is_custom_place(src_place) && // NOLINT
platform::is_custom_place(
dst_place)) { /* custom_device -> custom_device*/
......@@ -483,11 +482,11 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
else if (platform::is_xpu_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
} // NOLINT
else if (platform::is_cpu_place(src_place) && // NOLINT
platform::is_xpu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
} // NOLINT
else if (platform::is_xpu_place(src_place) && // NOLINT
platform::is_xpu_place(dst_place)) {
if (src_ptr == dst_ptr) {
......@@ -502,7 +501,7 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
auto xpu_ctx = platform::DeviceContextPool::Instance().Get(xpu_dst_place);
xpu_ctx->Wait();
}
}
} // NOLINT
else { // NOLINT
PADDLE_THROW(platform::errors::Unimplemented(
"Copy from %s to %s is not supported.", src_place, dst_place));
......@@ -601,6 +600,29 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
"Copy from %s to %s is not supported.", src_place, dst_place));
}
#endif
#ifdef PADDLE_WITH_IPU
else if (platform::is_ipu_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
else if (platform::is_cpu_place(src_place) && // NOLINT
platform::is_ipu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
else if (platform::is_ipu_place(src_place) && // NOLINT
platform::is_ipu_place(dst_place)) {
if (src_ptr == dst_ptr) {
VLOG(3) << "Skip copy the same data sync from " << src_place << " to "
<< dst_place;
return;
}
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
else { // NOLINT
PADDLE_THROW(platform::errors::Unimplemented(
"Copy from %s to %s is not supported.", src_place, dst_place));
}
#endif
}
template <typename Predicate, typename DevCtx>
......
......@@ -1109,8 +1109,9 @@ void Reducer::FinalizeBackward() {
if (find_unused_vars_each_step_) {
// TODO(liuyuhui) support xpu about Tensorcopy/TensorFromVector/TensorToVector
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_GLOO) || defined(PADDLE_WITH_ASCEND_CL)
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_GLOO) || defined(PADDLE_WITH_ASCEND_CL) || \
defined(PADDLE_WITH_CNCL)
ProcessUnusedDenseVars();
#endif
// Initialize local used vars
......
......@@ -17,7 +17,7 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
USE_OP_ITSELF(conv2d);
USE_OP(conv2d_transpose);
USE_OP_ITSELF(conv2d_transpose);
namespace paddle {
namespace inference {
......
......@@ -40,6 +40,13 @@ class FeedVariableVisitor : public boost::static_visitor<void> {
out_var_->GetMutable<framework::LoDTensor>();
if (platform::is_same_place(in_tensor.place(), place_)) {
out_tensor->ShareDataWith(in_tensor);
#ifdef PADDLE_WITH_IPU
} else if (platform::is_ipu_place(place_)) {
// For ipu, both in_tensor and out_tensor are allocated on cpu,
// PopART will copy tensor from host automatically,
// no TensorCopy() is required here.
out_tensor->ShareDataWith(in_tensor);
#endif
} else {
platform::DeviceContext *context =
platform::DeviceContextPool::Instance().Get(place_);
......
......@@ -19,14 +19,16 @@ namespace operators {
template <typename DeviceContext, typename T>
class GemmConvXPUKernel : public framework::OpKernel<T> {
using XPUT = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input");
void Compute(const framework::ExecutionContext &context) const override {
const Tensor *input = context.Input<Tensor>("Input");
// The filter will be reshaped in the calculations,
// so here use an assignment operation,
// that avoids modifying the variable in the Scope.
Tensor filter = *context.Input<Tensor>("Filter");
Tensor* output = context.Output<Tensor>("Output");
Tensor *output = context.Output<Tensor>("Output");
output->mutable_data<T>(context.GetPlace());
int groups = context.Attr<int>("groups");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
......@@ -53,11 +55,16 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
const int img_h = static_cast<int>(input->dims()[2]);
const int img_w = static_cast<int>(input->dims()[3]);
const int f = static_cast<int>(filter.dims()[0]);
auto& dev_ctx = context.template device_context<DeviceContext>();
int r = xpu::conv2d<float, float, float, int16_t>(
dev_ctx.x_context(), input->data<float>(), filter.data<float>(),
output->data<float>(), batch_size, img_c, img_h, img_w, f, ksize,
strides, paddings, dilations, groups, nullptr, nullptr, nullptr, true);
const XPUT *input_data = reinterpret_cast<const XPUT *>(input->data<T>());
const XPUT *filter_data = reinterpret_cast<const XPUT *>(filter.data<T>());
XPUT *output_data = reinterpret_cast<XPUT *>(output->data<T>());
auto &dev_ctx = context.template device_context<DeviceContext>();
int r = xpu::conv2d<XPUT, XPUT, XPUT, int16_t>(
dev_ctx.x_context(), input_data, filter_data, output_data, batch_size,
img_c, img_h, img_w, f, ksize, strides, paddings, dilations, groups,
nullptr, nullptr, nullptr, true);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU conv kernel return wrong value[%d %s]",
......@@ -67,14 +74,16 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
template <typename DeviceContext, typename T>
class GemmConvGradXPUKernel : public framework::OpKernel<T> {
using XPUT = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input");
const Tensor* output_grad =
void Compute(const framework::ExecutionContext &context) const override {
const Tensor *input = context.Input<Tensor>("Input");
const Tensor *output_grad =
context.Input<Tensor>(framework::GradVarName("Output"));
Tensor* input_grad =
Tensor *input_grad =
context.Output<Tensor>(framework::GradVarName("Input"));
Tensor* filter_grad =
Tensor *filter_grad =
context.Output<Tensor>(framework::GradVarName("Filter"));
// The filter and filter_grad will be reshaped in the calculations,
// so here use an assignment operation,
......@@ -107,19 +116,27 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
const int img_h = static_cast<int>(input->dims()[2]);
const int img_w = static_cast<int>(input->dims()[3]);
const int f = static_cast<int>(filter.dims()[0]);
const XPUT *input_data = reinterpret_cast<const XPUT *>(input->data<T>());
const XPUT *filter_data = reinterpret_cast<const XPUT *>(filter.data<T>());
const XPUT *output_grad_data =
reinterpret_cast<const XPUT *>(output_grad->data<T>());
XPUT *input_grad_data = nullptr;
if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace());
input_grad_data = reinterpret_cast<XPUT *>(input_grad->data<T>());
}
XPUT *filter_grad_data = nullptr;
if (filter_grad) {
filter_grad->mutable_data<T>(context.GetPlace());
filter_grad_data = reinterpret_cast<XPUT *>(filter_grad->data<T>());
}
auto& dev_ctx = context.template device_context<DeviceContext>();
int r = xpu::conv2d_grad<float, float, float, int16_t>(
dev_ctx.x_context(), input->data<T>(), filter.data<T>(),
output_grad->data<T>(), input_grad ? input_grad->data<T>() : nullptr,
filter_grad ? filter_grad->data<T>() : nullptr, batch_size, img_c,
img_h, img_w, f, ksize, strides, paddings, dilations, groups, nullptr,
nullptr, nullptr, nullptr, nullptr, true);
auto &dev_ctx = context.template device_context<DeviceContext>();
int r = xpu::conv2d_grad<XPUT, XPUT, XPUT, int16_t>(
dev_ctx.x_context(), input_data, filter_data, output_grad_data,
input_grad_data, filter_grad_data, batch_size, img_c, img_h, img_w, f,
ksize, strides, paddings, dilations, groups, nullptr, nullptr, nullptr,
nullptr, nullptr, true);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU conv kernel return wrong value[%d %s]",
......@@ -130,14 +147,22 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
depthwise_conv2d,
ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(
conv2d, ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext, float>);
conv2d, ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(
conv2d_grad,
ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(
depthwise_conv2d,
ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(
depthwise_conv2d_grad,
ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
#endif
......@@ -13,13 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/conv_transpose_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/cudnn_workspace_helper.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
......@@ -29,165 +33,6 @@ namespace operators {
using DataLayout = framework::DataLayout;
void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "ConvTranspose");
OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter", "ConvTranspose");
OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output", "ConvTranspose");
auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter");
std::vector<int> output_size =
ctx->Attrs().Get<std::vector<int>>("output_size");
std::vector<int> output_padding =
ctx->Attrs().Get<std::vector<int>>("output_padding");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations");
int groups = ctx->Attrs().Get<int>("groups");
std::string padding_algorithm =
ctx->Attrs().Get<std::string>("padding_algorithm");
const std::string data_layout_str =
ctx->Attrs().Get<std::string>("data_format");
const DataLayout data_layout =
ctx->IsRunMKLDNNKernel() ? DataLayout::kNCHW
: framework::StringToDataLayout(data_layout_str);
PADDLE_ENFORCE_EQ(in_dims.size() == 4 || in_dims.size() == 5, true,
platform::errors::InvalidArgument(
"Input of Op(conv_transpose) should be 4-D or "
"5-D Tensor. But received: %u-D Tensor, "
"the shape of input is [%s]",
in_dims.size(), in_dims));
PADDLE_ENFORCE_EQ(
in_dims.size(), filter_dims.size(),
platform::errors::InvalidArgument(
"The input's dimension size and filter's dimension size of "
"Op (conv_transpose) should be equal. But received: the shape of "
"input is [%s], the dimension size of input is [%d], the shape "
"of filter is [%s], the dimension size of filter is [%d]. ",
in_dims, in_dims.size(), filter_dims, filter_dims.size()));
int stride_size = strides.size();
for (int i = 0; i < stride_size; ++i) {
PADDLE_ENFORCE_GT(
strides[i], 0,
platform::errors::InvalidArgument(
"The stride of Op(Conv) should be larget than 0, but received "
"stride is %d.",
strides[i]));
}
int in_sub_stride_size = in_dims.size() - stride_size;
PADDLE_ENFORCE_EQ(
in_dims.size() - strides.size(), 2U,
platform::errors::InvalidArgument(
"The input's dimension size minus Attr(stride)'s size must "
"be euqal to 2 for Op(conv_transpose). But received: [%d], the "
"input's dimension size is [%d], the shape of input "
"is [%s], the Attr(stride)'s size is [%d].",
in_sub_stride_size, in_dims.size(), in_dims, strides.size()));
if (output_size.size())
PADDLE_ENFORCE_EQ(
output_size.size(), strides.size(),
platform::errors::InvalidArgument(
"The Attr(output_size) and Attr(stride) of Op(conv_transpose) "
"should be the same."));
if (output_padding.size())
PADDLE_ENFORCE_EQ(
output_padding.size(), strides.size(),
platform::errors::InvalidArgument(
"The Attr(output_padding) and Attr(stride) of Op(conv_transpose) "
"should be the same."));
const int64_t C =
(data_layout != DataLayout::kNHWC ? in_dims[1]
: in_dims[in_dims.size() - 1]);
PADDLE_ENFORCE_EQ(
C, filter_dims[0],
platform::errors::InvalidArgument(
"The number of input channels should be equal to filter channels "
"for Op(conv_transpose). But received: the input's channels is "
"[%d], the shape of input is [%s], the filter's channels is [%d], "
"the shape of filter is [%s]. The data_format is %s."
"The error may come from wrong data_format setting.",
C, in_dims, filter_dims[0], filter_dims, data_layout_str));
framework::DDim in_data_dims;
if (data_layout != DataLayout::kNHWC) {
in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
} else {
in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
}
framework::DDim filter_data_dims =
phi::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
std::vector<int64_t> output_shape({in_dims[0]});
if (data_layout != DataLayout::kNHWC) {
output_shape.push_back(filter_dims[1] * groups);
}
const int offset = (data_layout != DataLayout::kNHWC ? 2 : 1);
for (size_t i = 0; i < strides.size(); ++i) {
auto filter_extent = dilations[i] * (filter_dims[i + 2] - 1) + 1;
auto infer_shape = (ctx->IsRuntime() || in_dims[i + offset] > 0)
? (in_dims[i + offset] - 1) * strides[i] -
paddings[2 * i] - paddings[2 * i + 1] +
filter_extent
: -1;
if (output_size.size()) {
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_GE(
output_size[i], infer_shape,
platform::errors::InvalidArgument(
"output_size of Op(ConvTransposeOp) should not be "
"less than the infered output size. But received output_size = "
"[%s], whose dim %d is less than the infered output size [%s]",
phi::make_ddim(output_size).to_str(), i, infer_shape));
PADDLE_ENFORCE_LT(
output_size[i], infer_shape + strides[i],
platform::errors::InvalidArgument(
"output_size of Op(ConvTransposeOp) should be less "
"than infered size + stride. But received output_size = [%s], "
"whose dim %d is not less than the infered output size (%d) + "
"stride (%d) = %d",
phi::make_ddim(output_size).to_str(), i, infer_shape,
strides[i], infer_shape + strides[i]));
}
output_shape.push_back(output_size[i]);
} else if (output_padding.size()) {
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_GE(
output_padding[i], 0,
platform::errors::InvalidArgument(
"output_padding of Op(ConvTransposeOp) should not be "
"less than the 0. But received output_padding = "
"[%s], whose dim %d is less than 0",
phi::make_ddim(output_padding).to_str(), i));
PADDLE_ENFORCE_LT(
output_padding[i], std::max(strides[i], dilations[i]),
platform::errors::InvalidArgument(
"output_padding of Op(ConvTransposeOp) should be less "
"than either stride or dilation. But received output_size = "
"[%s], "
"whose dim %d is not less than either stride (%d) or "
"dilation (%d)",
phi::make_ddim(output_size).to_str(), i, strides[i],
dilations[i]));
}
output_shape.push_back((infer_shape + output_padding[i]));
} else {
output_shape.push_back(infer_shape);
}
}
if (data_layout == DataLayout::kNHWC) {
output_shape.push_back(filter_dims[1] * groups);
}
ctx->SetOutputDim("Output", phi::make_ddim(output_shape));
}
framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain};
......@@ -217,7 +62,7 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
}
framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
// Only input require reshaping, weights and
......@@ -493,17 +338,6 @@ Example:
)DOC");
}
void ConvTransposeOpGrad::InferShape(framework::InferShapeContext* ctx) const {
auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter");
if (ctx->HasOutput(framework::GradVarName("Input"))) {
ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
}
if (ctx->HasOutput(framework::GradVarName("Filter"))) {
ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
}
}
framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
bool use_cudnn =
......@@ -587,24 +421,6 @@ class ConvTransposeDoubleGradMaker : public framework::SingleGradOpMaker<T> {
}
};
void ConvTransposeOpDoubleGrad::InferShape(
framework::InferShapeContext* ctx) const {
auto x_dims = ctx->GetInputDim("Input");
auto w_dims = ctx->GetInputDim("Filter");
auto do_dims = ctx->GetInputDim("DOutput");
if (ctx->HasOutput("DDOutput") &&
(ctx->HasInput("DDInput") || (ctx->HasInput("DDFilter")))) {
ctx->SetOutputDim("DDOutput", do_dims);
}
if (ctx->HasOutput("DFilter") && ctx->HasInput("DDInput")) {
ctx->SetOutputDim("DFilter", w_dims);
}
if (ctx->HasOutput("DInput") && ctx->HasInput("DDFilter")) {
ctx->SetOutputDim("DInput", x_dims);
}
}
framework::OpKernelType ConvTransposeOpDoubleGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
bool use_cudnn =
......@@ -635,59 +451,57 @@ framework::OpKernelType ConvTransposeOpDoubleGrad::GetExpectedKernelType(
namespace ops = paddle::operators;
// conv2d_transpose
DECLARE_INFER_SHAPE_FUNCTOR(conv2d_transpose, Conv2dTranposeInferShapeFunctor,
PD_INFER_META(phi::ConvTransposeInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(conv2d_transpose_grad,
Conv2dTranposeGradInferShapeFunctor,
PD_INFER_META(phi::ConvTransposeGradInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(
conv2d_transpose_grad_grad, Conv2dTranposeDoubleGradInferShapeFunctor,
PD_INFER_META(phi::Conv2dTransposeDoubleGradInferMeta));
REGISTER_OPERATOR(conv2d_transpose, ops::ConvTransposeOp,
ops::Conv2DTransposeOpMaker,
ops::ConvTransposeGradOpMaker<paddle::framework::OpDesc>,
ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
conv2d_transpose_grad, ops::ConvTransposeOpGrad,
ops::ConvTransposeDoubleGradMaker<paddle::framework::OpDesc>,
ops::ConvTransposeDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(conv2d_transpose_grad_grad, ops::ConvTransposeOpDoubleGrad);
REGISTER_OP_CPU_KERNEL(
conv2d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
conv2d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
double>);
ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>,
Conv2dTranposeInferShapeFunctor);
REGISTER_OPERATOR(conv2d_transpose_grad, ops::ConvTransposeOpGrad,
ops::ConvTransposeDoubleGradMaker<paddle::framework::OpDesc>,
ops::ConvTransposeDoubleGradMaker<paddle::imperative::OpBase>,
Conv2dTranposeGradInferShapeFunctor);
REGISTER_OPERATOR(conv2d_transpose_grad_grad, ops::ConvTransposeOpDoubleGrad,
Conv2dTranposeDoubleGradInferShapeFunctor);
// conv3d_transpose
DECLARE_INFER_SHAPE_FUNCTOR(conv3d_transpose, Conv3dTranposeInferShapeFunctor,
PD_INFER_META(phi::ConvTransposeInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(conv3d_transpose_grad,
Conv3dTranposeGradInferShapeFunctor,
PD_INFER_META(phi::ConvTransposeGradInferMeta));
REGISTER_OPERATOR(conv3d_transpose, ops::ConvTransposeOp,
ops::Conv3DTransposeOpMaker,
ops::ConvTransposeGradOpMaker<paddle::framework::OpDesc>,
ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(conv3d_transpose_grad, ops::ConvTransposeOpGrad);
REGISTER_OP_CPU_KERNEL(
conv3d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
conv3d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
double>);
ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>,
Conv3dTranposeInferShapeFunctor);
REGISTER_OPERATOR(conv3d_transpose_grad, ops::ConvTransposeOpGrad,
Conv3dTranposeGradInferShapeFunctor);
// depthwise conv2d_transpose
DECLARE_INFER_SHAPE_FUNCTOR(depthwise_conv2d_transpose,
DepthWiseConv2dTranposeInferShapeFunctor,
PD_INFER_META(phi::ConvTransposeInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(depthwise_conv2d_transpose_grad,
DepthWiseConv2dTranposeGradInferShapeFunctor,
PD_INFER_META(phi::ConvTransposeGradInferMeta));
REGISTER_OPERATOR(depthwise_conv2d_transpose, ops::ConvTransposeOp,
ops::Conv2DTransposeOpMaker,
ops::ConvTransposeGradOpMaker<paddle::framework::OpDesc>,
ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(depthwise_conv2d_transpose_grad, ops::ConvTransposeOpGrad);
REGISTER_OP_CPU_KERNEL(
depthwise_conv2d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
depthwise_conv2d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
double>);
ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>,
DepthWiseConv2dTranposeInferShapeFunctor);
REGISTER_OPERATOR(depthwise_conv2d_transpose_grad, ops::ConvTransposeOpGrad,
DepthWiseConv2dTranposeGradInferShapeFunctor);
REGISTER_OP_VERSION(conv_transpose)
.AddCheckpoint(
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/conv_transpose_op.h"
#include "paddle/phi/kernels/gpu/depthwise_conv.h"
namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
template <typename DeviceContext, typename T>
class DepthwiseConvTransposeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const std::string data_layout_str =
context.Attr<std::string>("data_format");
const framework::DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
const Tensor* input = context.Input<Tensor>("Input");
Tensor filter = *context.Input<Tensor>("Filter");
Tensor* output = context.Output<Tensor>("Output");
output->mutable_data<T>(context.GetPlace());
int groups = context.Attr<int>("groups");
PADDLE_ENFORCE_EQ(
groups, filter.dims()[0],
platform::errors::InvalidArgument(
"groups should be error to the 1st dimension of filter. But "
"received groups is %d and filter dimension[0] is %d",
groups, filter.dims()[0]));
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
std::string padding_algorithm =
context.Attr<std::string>("padding_algorithm");
for (auto v : dilations) {
PADDLE_ENFORCE_EQ(v, 1, platform::errors::InvalidArgument(
"dilations should be 1 in depthwise conv. "
"But received dilations is %d",
v));
}
auto in_dims = input->dims();
auto filter_dims = filter.dims();
framework::DDim in_data_dims;
if (data_layout != framework::DataLayout::kNHWC) {
in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
} else {
in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
}
framework::DDim filter_data_dims =
phi::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
output->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<DeviceContext>();
phi::funcs::SetConstant<DeviceContext, T> set_zero;
set_zero(dev_ctx, output, static_cast<T>(0));
math::DepthwiseConvInputGradFunctor<phi::GPUContext, T>
depthwiseConvInputGrad;
depthwiseConvInputGrad(
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*output, filter, *input, strides,
std::vector<int>{paddings[0], paddings[2], paddings[1], paddings[3]},
dilations, output, data_layout);
}
};
template <typename DeviceContext, typename T>
class DepthwiseConvTransposeGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const std::string data_layout_str =
context.Attr<std::string>("data_format");
const framework::DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
const Tensor* input = context.Input<Tensor>("Input");
const Tensor* output_grad =
context.Input<Tensor>(framework::GradVarName("Output"));
Tensor* input_grad =
context.Output<Tensor>(framework::GradVarName("Input"));
Tensor* filter_grad =
context.Output<Tensor>(framework::GradVarName("Filter"));
Tensor filter = *context.Input<Tensor>("Filter");
if (!input_grad && !filter_grad) return;
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
std::string padding_algorithm =
context.Attr<std::string>("padding_algorithm");
auto in_dims = input->dims();
auto filter_dims = filter.dims();
framework::DDim in_data_dims;
if (data_layout != framework::DataLayout::kNHWC) {
in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
} else {
in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
}
framework::DDim filter_data_dims =
phi::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
if (input_grad) {
math::DepthwiseConvFunctor<phi::GPUContext, T> depthwiseConv;
depthwiseConv(
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*output_grad, filter, strides,
std::vector<int>{paddings[0], paddings[2], paddings[1], paddings[3]},
dilations, input_grad, data_layout);
}
if (filter_grad) {
phi::funcs::SetConstant<DeviceContext, T> set_zero;
filter_grad->mutable_data<T>(context.GetPlace());
set_zero(dev_ctx, filter_grad, static_cast<T>(0));
math::DepthwiseConvFilterGradFunctor<phi::GPUContext, T>
depthwiseConvFilterGrad;
depthwiseConvFilterGrad(
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*output_grad, *input, strides,
std::vector<int>{paddings[0], paddings[2], paddings[1], paddings[3]},
dilations, filter_grad, data_layout);
}
}
};
} // namespace operators
} // namespace paddle
// conv2d
REGISTER_OP_CUDA_KERNEL(conv2d_transpose,
ops::GemmConvTransposeKernel<CUDA, float>,
ops::GemmConvTransposeKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(conv2d_transpose_grad,
ops::GemmConvTransposeGradKernel<CUDA, float>,
ops::GemmConvTransposeGradKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(conv2d_transpose_grad_grad,
ops::GemmConvTransposeGradKernel<CUDA, float>,
ops::GemmConvTransposeGradKernel<CUDA, double>);
// conv3d
REGISTER_OP_CUDA_KERNEL(conv3d_transpose,
ops::GemmConvTransposeKernel<CUDA, float>,
ops::GemmConvTransposeKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(conv3d_transpose_grad,
ops::GemmConvTransposeGradKernel<CUDA, float>,
ops::GemmConvTransposeGradKernel<CUDA, double>);
// depthwise conv2d
REGISTER_OP_CUDA_KERNEL(depthwise_conv2d_transpose,
ops::DepthwiseConvTransposeKernel<CUDA, float>,
ops::DepthwiseConvTransposeKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(depthwise_conv2d_transpose_grad,
ops::DepthwiseConvTransposeGradKernel<CUDA, float>,
ops::DepthwiseConvTransposeGradKernel<CUDA, double>);
......@@ -13,11 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/conv_transpose_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using NPUDeviceContext = platform::NPUDeviceContext;
template <typename T>
......@@ -55,8 +59,8 @@ class Conv2DTransposeNPUKernel : public framework::OpKernel<T> {
filter_data_dims = phi::slice_ddim(filter_dims, 2, in_dims.size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&padding, &dilation, padding_algorithm,
in_data_dims, stride, ksize);
phi::UpdatePaddingAndDilation(&padding, &dilation, padding_algorithm,
in_data_dims, stride, ksize);
// construct NPU attr
std::vector<int> strides(4, 1);
......@@ -137,8 +141,8 @@ class Conv2DTransposeGradNPUKernel : public framework::OpKernel<T> {
framework::DDim filter_data_dims =
phi::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
phi::UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
std::vector<int> strides_vec(4, 1);
std::vector<int> dilations_vec(4, 1);
......
......@@ -8,15 +8,22 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/conv_transpose_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
#ifdef PADDLE_WITH_XPU
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// target_len == 2 || target_len == 4
inline std::vector<int> vector_extend(const std::vector<int>& src,
int target_len) {
......@@ -61,8 +68,8 @@ class Conv2DTransposeXPUKernel : public framework::OpKernel<T> {
framework::DDim filter_data_dims =
phi::slice_ddim(filter.dims(), 2, filter.dims().size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
phi::UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
const int batch_size = static_cast<int>(input->dims()[0]);
const int img_yc = static_cast<int>(input->dims()[1]);
......@@ -135,8 +142,8 @@ class Conv2DTransposeGradXPUKernel : public framework::OpKernel<T> {
framework::DDim filter_data_dims =
phi::slice_ddim(filter.dims(), 2, filter.dims().size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
phi::UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
const int batch_size = static_cast<int>(input->dims()[0]);
const int img_yc = static_cast<int>(input->dims()[1]);
......
......@@ -243,8 +243,6 @@ class ConcatFunctor<platform::MLUDeviceContext, T> {
const int axis_t = axis;
const int ins_size_t = ins_size;
auto place = context.GetPlace();
output->mutable_data<T>(place);
// mlu should do sth
// init ins tensors
......@@ -295,7 +293,6 @@ class SplitFunctor<platform::MLUDeviceContext, T> {
std::vector<cnnlTensorDescriptor_t> desc_vector;
for (size_t i = 0; i < out_size; i++) {
(*outputs)[i]->Resize(outs_dims[i]);
(*outputs)[i]->mutable_data<T>(context.GetPlace());
output_descs.emplace_back(
MLUCnnlTensorDesc(*(*outputs)[i], CNNL_LAYOUT_ARRAY,
ToCnnlDataType((*outputs)[i]->dtype())));
......
......@@ -12,9 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/frobenius_norm_op.h"
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace framework {
......@@ -56,22 +59,12 @@ class FrobeniusNormOpMaker : public ops::ReduceOpMaker {
virtual std::string GetOpType() const { return "Reduce frobenius_norm"; }
};
DECLARE_INFER_SHAPE_FUNCTOR(frobenius_norm, FrobeniusNormInferShapeFunctor,
PD_INFER_META(phi::ReduceInferMetaBase));
REGISTER_OPERATOR(frobenius_norm, ops::ReduceOp, FrobeniusNormOpMaker,
ops::FrobeniusNormOpGradMaker<paddle::framework::OpDesc>,
ops::FrobeniusNormOpGradMaker<paddle::imperative::OpBase>);
ops::FrobeniusNormOpGradMaker<paddle::imperative::OpBase>,
FrobeniusNormInferShapeFunctor);
REGISTER_OPERATOR(frobenius_norm_grad, ops::ReduceGradOp);
REGISTER_OP_CPU_KERNEL(frobenius_norm,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
float, ops::FrobeniusNormFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
double, ops::FrobeniusNormFunctor>);
template <typename T>
using CPUFrobeniusNormGradKernel =
ops::FrobeniusNormGradKernel<paddle::platform::CPUDeviceContext, T,
ops::FrobeniusNormGradFunctor>;
REGISTER_OP_CPU_KERNEL(frobenius_norm_grad, CPUFrobeniusNormGradKernel<float>,
CPUFrobeniusNormGradKernel<double>);
......@@ -117,7 +117,7 @@ endif()
cc_library(cudnn_workspace_helper SRCS cudnn_workspace_helper.cc DEPS boost)
# seperate init from device_context to avoid cycle dependencies
cc_library(init SRCS init.cc DEPS device_context custom_kernel)
cc_library(init SRCS init.cc DEPS device_context custom_kernel context_pool)
# memcpy depends on device_context, here add deps individually for
# avoiding cycle dependencies
......
......@@ -13,7 +13,7 @@ IF(WITH_IPU)
"ipu_device.cc"
)
cc_library(ipu_backend SRCS ${IPU_BACKEND_SRC} DEPS popart-only graph graph_helper)
cc_library(ipu_backend SRCS ${IPU_BACKEND_SRC} DEPS popart-only graph graph_helper popdist)
cc_library(ipu_info SRCS ${IPU_INFO_SRC} DEPS popart-only enforce)
add_library(paddle_ipu SHARED ${PADDLE_IPU_SRC})
add_dependencies(paddle_ipu ipu_backend)
......
......@@ -32,6 +32,7 @@ IpuBackend* IpuBackend::GetInstance() {
IpuBackend::IpuBackend() {
compiler_ = std::make_unique<Compiler>();
executor_ = std::make_unique<Executor>();
timer_ = std::make_unique<platform::Timer>();
}
IpuBackend::~IpuBackend() {
......@@ -43,6 +44,7 @@ void IpuBackend::Compile(Graph* graph,
const std::vector<std::string>& feed_list,
const std::vector<std::string>& fetch_list) {
VLOG(10) << "enter IpuBackend::Compile";
is_compiled_ = false;
compiler_->Prepare(graph);
compiler_->InitInputs(feed_list);
compiler_->LowerConstants(scope_);
......@@ -52,31 +54,25 @@ void IpuBackend::Compile(Graph* graph,
if (ipu_strategy_->is_training) {
compiler_->LowerOptimizer(scope_);
}
if (!ipu_strategy_->onnx_dump_path.empty()) {
SaveModelProto(ipu_strategy_->onnx_dump_path);
}
executor_->SetCompilerResources(compiler_->GetResources());
executor_->Prepare(compiler_->GetModelProto());
is_compiled_ = true;
// when call compile, means a new graph
is_prepared_ = false;
VLOG(10) << "leave IpuBackend::Compile";
}
void IpuBackend::Run(const std::vector<const Tensor*>& inputs,
const std::vector<Tensor*>& outputs,
const framework::ExecutionContext& ctx) {
Prepare();
timer_->Start();
executor_->Run(inputs, outputs, ctx);
timer_->Pause();
VLOG(10) << "[IPU Run]: " << timer_->ElapsedMS() << " (ms)";
}
void IpuBackend::Prepare() {
if (!is_prepared_) {
executor_->Prepare(compiler_->GetModelProto());
timer_.reset(new platform::Timer());
is_prepared_ = true;
}
}
void IpuBackend::WeightsToHost() { executor_->WeightsToHost(); }
void IpuBackend::Detach() { executor_->Detach(); }
......@@ -101,12 +97,10 @@ void IpuBackend::SetIpuStrategy(const IpuStrategy& strategy) {
}
void IpuBackend::SaveModelProto(const std::string& path) {
if (ipu_strategy_->is_training && is_prepared_) {
if (ipu_strategy_->is_training && is_compiled_) {
executor_->SaveModelToHost(path);
} else if (is_compiled_) {
compiler_->SaveModelProtoNoCheck(path);
} else {
LOG(WARNING) << "Model is empty";
compiler_->SaveModelProtoNoCheck(path);
}
}
......
......@@ -60,6 +60,9 @@ class IpuBackend {
const std::vector<Tensor *> &outputs,
const framework::ExecutionContext &ctx);
// Sync weights from IPU while training
void WeightsToHost();
// detach IPU manually
void Detach();
......@@ -76,22 +79,17 @@ class IpuBackend {
void SaveModelProto(const std::string &path);
private:
void Prepare();
private:
std::unique_ptr<Compiler> compiler_;
std::unique_ptr<Executor> executor_;
bool is_compiled_ = false;
bool is_prepared_ = false;
// not own
const Scope *scope_ = nullptr;
const IpuStrategy *ipu_strategy_ = nullptr;
private:
// time record for IpuBackend::Run
// own
std::unique_ptr<Compiler> compiler_;
std::unique_ptr<Executor> executor_;
std::unique_ptr<platform::Timer> timer_;
bool is_compiled_ = false;
DISABLE_COPY_AND_ASSIGN(IpuBackend);
};
......
......@@ -18,6 +18,7 @@
#include <popart/adaptive.hpp>
#include <popart/optimizer.hpp>
#include <popart/sgd.hpp>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/device/ipu/ipu_utils.h"
......@@ -25,13 +26,20 @@ namespace paddle {
namespace platform {
namespace ipu {
popart::AdamMode AdamModeFromStr(const std::string& str) {
popart::AdamMode AdamModeFromStr(const std::string& str,
const bool& use_no_bias_optimizer) {
if (str == "adam") {
return popart::AdamMode::Adam;
if (!use_no_bias_optimizer)
return popart::AdamMode::Adam;
else
return popart::AdamMode::AdamNoBias;
} else if (str == "adamax") {
return popart::AdamMode::AdaMax;
} else if (str == "lamb") {
return popart::AdamMode::Lamb;
if (!use_no_bias_optimizer)
return popart::AdamMode::Lamb;
else
return popart::AdamMode::LambNoBias;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Uknown AdamMode: %s, AdamMode must be one of these values: adam, "
......@@ -70,6 +78,17 @@ popart::WeightDecayMode WeightDecayModeFromStr(const std::string& str) {
}
}
popart::DataType DataTypeFromStr(const std::string& str) {
if (str == "FLOAT") {
return popart::DataType::FLOAT;
} else if (str == "FLOAT16") {
return popart::DataType::FLOAT16;
} else {
PADDLE_THROW(
platform::errors::Unimplemented("Unsupported DataType: %s", str));
}
}
template <typename T>
T GetAttrAllowNull(std::string attr, OpDesc* op_desc) {
if (op_desc->HasAttr(attr)) {
......@@ -122,6 +141,17 @@ void Compiler::Prepare(const Graph* graph) {
builder_ = popart::Builder::create();
resources_ = std::make_unique<CompilerResources>();
graph_helper_ = std::make_unique<GraphHelper>(graph);
// Set the flag of set_amp_for_all_
for (auto* node : graph_helper_->sorted_ops) {
auto* op_desc = node->Op();
auto op_type = op_desc->Type();
if (op_type == "popart_matmul") {
if (op_desc->HasAttr(sAvailMemAttribute)) {
set_amp_for_all_ = false;
return;
}
}
}
}
void Compiler::RegisterOpFunc() {
......@@ -155,7 +185,9 @@ void Compiler::RegisterOpFunc() {
auto debug_context = BuildDebugContext(op_desc); \
auto aiGraphcoreOpset = builder_->aiGraphcoreOpset1(); \
auto aiOnnxOpset = builder_->aiOnnxOpset11(); \
PushNameScope(op_desc); \
auto output_ids = OnnxImpl(inputs Args, debug_context); \
PopNameScope(op_desc); \
SetIpuIndexStage(output_ids, op_desc); \
SetAMPAttributes(output_ids, op_desc); \
SetSerializeAttributes(output_ids, op_desc); \
......@@ -241,7 +273,9 @@ void Compiler::LowerConstants(const Scope* scope) {
popart::TensorInfo tensor_info(PdDataType2PopartType(tensor->dtype()),
shape);
const_data.reset(new popart::ConstVoidData(tensor->data(), tensor_info));
PushNameScope(op_desc);
popart::TensorId result = builder_->aiOnnxOpset11().constant(*const_data);
PopNameScope(op_desc);
SetIpuIndexStage(result, op_desc);
resources_->tensors.emplace(tensor_name, result);
}
......@@ -261,6 +295,10 @@ void Compiler::LowerWeights(const Scope* scope) {
VLOG(10) << "found existed one, skip lowering Weight: " << var_name;
continue;
}
if (var_name.rfind("learning_rate", 0) == 0) {
VLOG(10) << "skip learning_rate_var: " << var_name;
continue;
}
VLOG(10) << "lowering weight: " << var_name;
auto var = scope->FindVar(var_name);
......@@ -273,10 +311,15 @@ void Compiler::LowerWeights(const Scope* scope) {
}
popart::TensorInfo tensor_info(dtype, shape);
popart::ConstVoidData const_data{tensor.data(), tensor_info};
popart::TensorId result =
builder_->addInitializedInputTensor(const_data, var_name);
resources_->tensors.emplace(var_name, result);
resources_->weights.push_back(result);
if (!node->outputs.empty()) {
auto op_node = node->outputs[0];
PushNameScope(op_node->Op());
popart::TensorId result =
builder_->addInitializedInputTensor(const_data, var_name);
PopNameScope(op_node->Op());
resources_->tensors.emplace(var_name, result);
resources_->weights.push_back(var_name);
}
}
}
}
......@@ -298,7 +341,10 @@ void Compiler::LowerBody() {
} else if (op_type == "popart_checkpointoutput") {
auto inputs = GetOpInputs(op_desc);
auto outputs = GetOpOutputs(op_desc);
PushNameScope(op_desc);
auto output_ids = builder_->checkpointOutput(inputs);
PopNameScope(op_desc);
SetIpuIndexStage(output_ids, op_desc);
InsertTensors(outputs, output_ids);
} else if (op_type == "popart_custom_op") {
auto inputs = GetOpInputs(op_desc);
......@@ -313,9 +359,11 @@ void Compiler::LowerBody() {
BOOST_GET_CONST(std::string, op_desc->GetAttr("__op_type"));
VLOG(10) << "Build graph from custom op: " << __op_type;
auto it = custom_ops_.find(__op_type);
PushNameScope(op_desc);
auto output_ids =
builder_->customOp(it->second.popart_op, it->second.popart_op.version,
inputs, outputs.size(), attributes, debug_context);
PopNameScope(op_desc);
SetIpuIndexStage(output_ids, op_desc);
InsertTensors(outputs, output_ids);
} else if (op_type == "popart_printtensor") {
......@@ -325,8 +373,10 @@ void Compiler::LowerBody() {
auto print_gradient =
BOOST_GET_CONST(int64_t, op_desc->GetAttr("print_gradient"));
auto title = BOOST_GET_CONST(std::string, op_desc->GetAttr("title"));
PushNameScope(op_desc);
auto output_ids = builder_->aiGraphcoreOpset1().printtensor(
inputs, print_gradient, debug_context, title);
PopNameScope(op_desc);
SetIpuIndexStage(output_ids, op_desc);
InsertTensors(outputs, output_ids);
} else {
......@@ -367,8 +417,31 @@ void Compiler::LowerOptimizer(const Scope* scope) {
resources_->with_lr_sched = false;
}
VLOG(10) << "Set initial lr: " << resources_->lr;
auto loss_scaling = ipu_strategy_->loss_scaling;
// Get the type of optimizer
auto type = BOOST_GET_CONST(std::string, op_desc->GetAttr("type"));
// Set weight decay by tensor names for Lamb
auto weight_decay_vars = BOOST_GET_CONST(
std::vector<std::string>, op_desc->GetAttr("weight_decay_vars"));
auto weight_decay_values = BOOST_GET_CONST(
std::vector<float>, op_desc->GetAttr("weight_decay_values"));
// Get the maximum permissible value for gradient clipping
std::vector<popart::ClipNormSettings> clip_norm_settings = {};
if (op_desc->HasAttr("clip_norm")) {
auto clip_norm = BOOST_GET_CONST(float, op_desc->GetAttr("clip_norm"));
clip_norm_settings.push_back(
popart::ClipNormSettings::clipAllWeights(clip_norm));
VLOG(10) << "Set the global gradient clipping with the maximum "
"permissible value: "
<< clip_norm;
}
// Values from ipu_strategy
auto loss_scaling = ipu_strategy_->loss_scaling;
auto accl1_type = DataTypeFromStr(ipu_strategy_->accl1_type);
auto accl2_type = DataTypeFromStr(ipu_strategy_->accl2_type);
auto accl3_type = DataTypeFromStr(ipu_strategy_->accl3_type);
if (type == "sgd") {
auto weight_decay =
BOOST_GET_CONST(float, op_desc->GetAttr("weight_decay"));
......@@ -376,12 +449,18 @@ void Compiler::LowerOptimizer(const Scope* scope) {
resources_->optimizer_fn = [=](float lr) {
return std::make_unique<popart::SGD>(
popart::OptimizerValue(lr, false),
popart::OptimizerValue(weight_decay, true),
popart::OptimizerValue(weight_decay, false),
popart::OptimizerValue(momentum, true),
popart::SGD::getUnsetDampening(),
popart::SGD::getUnsetVelocityScaling(),
popart::OptimizerValue(loss_scaling, true));
popart::OptimizerValue(loss_scaling, true), clip_norm_settings);
};
resources_->eval_optimizer = std::make_unique<popart::SGD>(
popart::OptimizerValue(0.0, false),
popart::OptimizerValue(0.0, false),
popart::OptimizerValue(0.0, true), popart::SGD::getUnsetDampening(),
popart::SGD::getUnsetVelocityScaling(),
popart::OptimizerValue(loss_scaling, true), clip_norm_settings);
} else if (type == "adam") {
auto weight_decay =
BOOST_GET_CONST(float, op_desc->GetAttr("weight_decay"));
......@@ -392,22 +471,79 @@ void Compiler::LowerOptimizer(const Scope* scope) {
VLOG(10) << "set max_weight_norm: " << mwn;
auto adam_mode_ =
BOOST_GET_CONST(std::string, op_desc->GetAttr("adam_mode"));
auto adam_mode = AdamModeFromStr(adam_mode_);
auto weight_decay_mode_ =
BOOST_GET_CONST(std::string, op_desc->GetAttr("weight_decay_mode"));
auto adam_mode =
AdamModeFromStr(adam_mode_, ipu_strategy_->use_no_bias_optimizer);
auto weight_decay_mode_ = ipu_strategy_->weight_decay_mode;
if (weight_decay_mode_.empty()) {
weight_decay_mode_ = BOOST_GET_CONST(
std::string, op_desc->GetAttr("weight_decay_mode"));
}
auto weight_decay_mode = WeightDecayModeFromStr(weight_decay_mode_);
resources_->optimizer_fn = [=](float lr) {
return std::make_unique<popart::Adam>(
popart::OptimizerValue(lr, false),
popart::OptimizerValue(weight_decay, true),
popart::OptimizerValue(beta1, true),
popart::OptimizerValue(beta2, true),
if (adam_mode == popart::AdamMode::Lamb ||
adam_mode == popart::AdamMode::LambNoBias) {
const std::map<std::string, std::pair<float, bool>>
optimizer_value = {{"defaultLearningRate", {lr, false}},
{"defaultBeta1", {beta1, false}},
{"defaultBeta2", {beta2, false}},
{"defaultEps", {eps, true}},
{"lossScaling", {loss_scaling, true}},
{"defaultMaxWeightNorm", {mwn, true}}};
auto optimizer_instance = std::make_unique<popart::Adam>(
optimizer_value, adam_mode, weight_decay_mode,
popart::DataType::UNDEFINED, accl1_type, accl2_type,
clip_norm_settings);
for (int i = 0; i < weight_decay_vars.size(); i++) {
optimizer_instance->insertSpecific(
weight_decay_vars[i],
{{"weightDecay", {weight_decay_values[i], false}}});
VLOG(10) << "Set Tensor " << weight_decay_vars[i]
<< " weight decay as " << weight_decay_values[i];
}
return optimizer_instance;
} else {
return std::make_unique<popart::Adam>(
popart::OptimizerValue(lr, false),
popart::OptimizerValue(weight_decay, false),
popart::OptimizerValue(beta1, false),
popart::OptimizerValue(beta2, false),
popart::OptimizerValue(eps, true),
popart::OptimizerValue(loss_scaling, true),
popart::OptimizerValue(mwn, true), adam_mode, weight_decay_mode,
popart::DataType::UNDEFINED, accl1_type, accl2_type,
clip_norm_settings);
}
};
if (adam_mode == popart::AdamMode::Lamb ||
adam_mode == popart::AdamMode::LambNoBias) {
const std::map<std::string, std::pair<float, bool>> optimizer_value =
{{"defaultLearningRate", {0.0, false}},
{"defaultBeta1", {beta1, false}},
{"defaultBeta2", {beta2, false}},
{"defaultEps", {eps, true}},
{"lossScaling", {loss_scaling, true}},
{"defaultMaxWeightNorm", {mwn, true}}};
auto eval_optimizer = std::make_unique<popart::Adam>(
optimizer_value, adam_mode, weight_decay_mode,
popart::DataType::UNDEFINED, popart::DataType::FLOAT,
popart::DataType::FLOAT, clip_norm_settings);
for (int i = 0; i < weight_decay_vars.size(); i++) {
eval_optimizer->insertSpecific(weight_decay_vars[i],
{{"weightDecay", {0.0, false}}});
}
resources_->eval_optimizer = std::move(eval_optimizer);
} else {
resources_->eval_optimizer = std::make_unique<popart::Adam>(
popart::OptimizerValue(0.0, false),
popart::OptimizerValue(0.0, false),
popart::OptimizerValue(beta1, false),
popart::OptimizerValue(beta2, false),
popart::OptimizerValue(eps, true),
popart::OptimizerValue(loss_scaling, true),
popart::OptimizerValue(mwn, true), adam_mode, weight_decay_mode,
popart::DataType::UNDEFINED, popart::DataType::FLOAT,
popart::DataType::FLOAT);
};
popart::DataType::FLOAT, clip_norm_settings);
}
} else if (type == "adaptive") {
auto alpha = BOOST_GET_CONST(float, op_desc->GetAttr("alpha"));
auto momentum = BOOST_GET_CONST(float, op_desc->GetAttr("momentum"));
......@@ -417,21 +553,33 @@ void Compiler::LowerOptimizer(const Scope* scope) {
auto adaptive_mode_ =
BOOST_GET_CONST(std::string, op_desc->GetAttr("adaptive_mode"));
auto adaptive_mode = AdaptiveModeFromStr(adaptive_mode_);
auto weight_decay_mode_ =
BOOST_GET_CONST(std::string, op_desc->GetAttr("weight_decay_mode"));
auto weight_decay_mode_ = ipu_strategy_->weight_decay_mode;
if (weight_decay_mode_.empty()) {
weight_decay_mode_ = BOOST_GET_CONST(
std::string, op_desc->GetAttr("weight_decay_mode"));
}
auto weight_decay_mode = WeightDecayModeFromStr(weight_decay_mode_);
resources_->optimizer_fn = [=](float lr) {
return std::make_unique<popart::Adaptive>(
popart::OptimizerValue(lr, false),
popart::OptimizerValue(weight_decay, true),
popart::OptimizerValue(weight_decay, false),
popart::OptimizerValue(alpha, true),
popart::OptimizerValue(momentum, true),
popart::OptimizerValue(eps, true),
popart::OptimizerValue(loss_scaling, true), adaptive_mode,
weight_decay_mode, popart::DataType::UNDEFINED,
popart::DataType::FLOAT, popart::DataType::FLOAT,
popart::DataType::FLOAT);
weight_decay_mode, popart::DataType::UNDEFINED, accl1_type,
accl2_type, accl3_type);
};
resources_->eval_optimizer = std::make_unique<popart::Adaptive>(
popart::OptimizerValue(0.0, false),
popart::OptimizerValue(0.0, false),
popart::OptimizerValue(alpha, true),
popart::OptimizerValue(momentum, true),
popart::OptimizerValue(eps, true),
popart::OptimizerValue(loss_scaling, true), adaptive_mode,
weight_decay_mode, popart::DataType::UNDEFINED,
popart::DataType::FLOAT, popart::DataType::FLOAT,
popart::DataType::UNDEFINED);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"optimizer %s is not implemented", type));
......@@ -510,9 +658,32 @@ void Compiler::SetAMPAttributes(const std::string& tensor_id,
const OpDesc* op_desc) {
VLOG(10) << "enter Compiler::SetAMPAttributes";
if (op_desc->Type() == "popart_matmul") {
auto amp = ipu_strategy_->available_memory_proportion;
if (amp > 0.0f && amp <= 1.0) {
builder_->setAvailableMemoryProportion(tensor_id, amp);
if (set_amp_for_all_) {
auto amp = ipu_strategy_->available_memory_proportion;
if (amp < 0.0f || amp > 1.0) {
PADDLE_THROW(platform::errors::InvalidArgument(
"AvailableMemoryProportion %f is invalid, which should be set 0 <= "
"amp <= 1",
amp));
}
if (amp > 0.0f) {
builder_->setAvailableMemoryProportion(tensor_id, amp);
}
} else {
if (op_desc->HasAttr(sAvailMemAttribute)) {
auto amp = BOOST_GET_CONST(float, op_desc->GetAttr(sAvailMemAttribute));
if (amp < 0.0f || amp > 1.0) {
PADDLE_THROW(platform::errors::InvalidArgument(
"AvailableMemoryProportion %f is invalid, which should be set 0 "
"<= amp <= 1",
amp));
}
if (amp > 0.0f) {
builder_->setAvailableMemoryProportion(tensor_id, amp);
VLOG(10) << "set available_memory_proportion for tensor: "
<< tensor_id << " as " << amp;
}
}
}
}
VLOG(10) << "leave Compiler::SetAMPAttributes";
......@@ -602,6 +773,29 @@ popart::DebugContext Compiler::BuildDebugContext(const OpDesc* op) {
return popart::DebugContext(op_identify_id);
}
void Compiler::PushNameScope(const OpDesc* op) {
auto op_namescope = BOOST_GET_CONST(std::string, op->GetAttr(sOpNamescope));
if (op_namescope == "/") {
return;
}
if (!op_namescope.empty()) {
op_namescope.pop_back();
}
if (!op_namescope.empty()) {
op_namescope.erase(op_namescope.begin());
}
VLOG(10) << "name_scope is: " << op_namescope;
builder_->pushNameScope(op_namescope);
}
void Compiler::PopNameScope(const OpDesc* op) {
auto op_namescope = BOOST_GET_CONST(std::string, op->GetAttr(sOpNamescope));
if (op_namescope == "/") {
return;
}
builder_->popNameScope();
}
} // namespace ipu
} // namespace platform
} // namespace paddle
......@@ -50,6 +50,8 @@ struct CompilerResources {
using OptimizerFn =
std::function<std::unique_ptr<popart::Optimizer>(float lr)>;
OptimizerFn optimizer_fn;
// The eval mode of optimizer in training
std::unique_ptr<popart::Optimizer> eval_optimizer;
public:
popart::Optimizer *Optimizer() { return optimizer.get(); }
......@@ -110,6 +112,7 @@ class Compiler {
void RegisterOpFunc();
std::vector<std::string> GetOpInputs(const OpDesc *op);
const std::vector<std::string> &GetOpOutputs(const OpDesc *op);
const std::string GetNameScope(const OpDesc *op);
popart::DebugContext BuildDebugContext(const OpDesc *op);
void InsertTensors(const std::vector<std::string> &output_names,
......@@ -126,6 +129,8 @@ class Compiler {
const OpDesc *op_desc);
void SetSerializeAttributes(const std::string &tensor_id,
const OpDesc *op_desc);
void PushNameScope(const OpDesc *op);
void PopNameScope(const OpDesc *op);
private:
std::unique_ptr<popart::Builder> builder_;
......@@ -137,6 +142,14 @@ class Compiler {
const IpuStrategy *ipu_strategy_ = nullptr;
std::map<std::string, IpuCustomOpIdentifier> custom_ops_;
// Used to choose the way to set amp for Ops
// If anyone op has the attr sAvailMemAttribute, the
// available_memory_proportion from ipu_strategy
// will be ignored and the Ops are set by their own sAvailMemAttribute. Else,
// all relevant Ops will be set by
// the available_memory_proportion from ipu_strategy.
bool set_amp_for_all_ = true;
};
} // namespace ipu
......
......@@ -64,15 +64,10 @@ void Executor::Prepare(const std::string &proto) {
WeightsFromPaddle();
VLOG(10) << "Copy weights from paddle to popart...done";
VLOG(10) << "Copy weights from host to device...";
session_->weightsFromHost();
VLOG(10) << "Copy weights from host to device...done";
if (ipu_strategy_->save_init_onnx) {
session_->modelToHost("test_init.onnx");
if (ipu_strategy_->random_seed != std::numeric_limits<std::uint64_t>::max()) {
VLOG(10) << "Setting random seed to: " << ipu_strategy_->random_seed;
session_->setRandomSeed(ipu_strategy_->random_seed);
}
// init run step
step_ = 0;
}
void Executor::Run(const std::vector<const Tensor *> &inputs,
......@@ -120,11 +115,17 @@ void Executor::Run(const std::vector<const Tensor *> &inputs,
VLOG(10) << "Prepared inputs/anchors";
if (ipu_strategy_->is_training && compiler_resources_->with_lr_sched) {
VLOG(10) << "Update learning_rate";
auto new_lr =
GetSingleVarFromScope<float>(scope_, compiler_resources_->lr_var);
VLOG(10) << "New Lr: " << new_lr;
auto *optimizer = compiler_resources_->UpdateOptimizer(new_lr);
popart::Optimizer *optimizer;
if (ipu_strategy_->runtime_options.enable_eval) {
VLOG(10) << "Switch optimizer to eval mode";
optimizer = compiler_resources_->eval_optimizer.get();
} else {
VLOG(10) << "Update learning_rate";
auto new_lr =
GetSingleVarFromScope<float>(scope_, compiler_resources_->lr_var);
VLOG(10) << "New Lr: " << new_lr;
optimizer = compiler_resources_->UpdateOptimizer(new_lr);
}
auto *session = dynamic_cast<popart::TrainingSession *>(session_.get());
session->updateOptimizerFromHost(optimizer);
}
......@@ -133,15 +134,13 @@ void Executor::Run(const std::vector<const Tensor *> &inputs,
VLOG(10) << "Running...";
session_->run(stepio);
VLOG(10) << "Running...done";
}
step_++;
if (ipu_strategy_->is_training &&
step_ % ipu_strategy_->save_per_n_step == 0) {
session_->weightsToHost();
void Executor::WeightsToHost() {
if (ipu_strategy_->is_training && session_) {
WeightsToPaddle();
if (ipu_strategy_->save_onnx_checkpoint) {
session_->modelToHost("test_last" + std::to_string(step_) + ".onnx");
}
} else {
LOG(WARNING) << "For a non-trainning graph, cannot sync weights from IPU.";
}
}
......@@ -153,6 +152,7 @@ void Executor::AcquireDevice() {
}
bool use_ipu_model = GetBoolEnv("POPLAR_IPUMODEL");
bool enable_distribution = ipu_strategy_->enable_distribution;
if (use_ipu_model) {
std::map<std::string, std::string> deviceOpts{
{
......@@ -162,6 +162,16 @@ void Executor::AcquireDevice() {
};
device_ = popart::DeviceManager::createDeviceManager().createIpuModelDevice(
deviceOpts);
} else if (enable_distribution) {
auto ipus_per_replica = ipu_strategy_->num_ipus /
ipu_strategy_->popart_options.replicatedGraphCount;
auto device_id = popdist_get_device(ipus_per_replica);
device_ = popart::DeviceManager::createDeviceManager().acquireDeviceById(
device_id);
PADDLE_ENFORCE_NOT_NULL(
device_, platform::errors::Unavailable(
"Can't attach IPU in distribution, ipu_num = %d.",
RequestIpus(ipu_strategy_->num_ipus)));
} else {
device_ =
popart::DeviceManager::createDeviceManager().acquireAvailableDevice(
......@@ -185,28 +195,29 @@ void Executor::SetWeightsIO() {
auto opt_type = compiler_resources_->optimizer_type;
VLOG(10) << "SetWeightsIO for " << opt_type;
auto pre_post_fix = GetOptPrePostfix(opt_type);
for (const auto &weight_id : compiler_resources_->weights) {
for (const auto &weight_pd : compiler_resources_->weights) {
for (const auto &pair : pre_post_fix) {
// pair.first : popart prefix, pair.second : paddle postfix
auto popart_var_name = pair.first + weight_id;
auto paddle_var_name = weight_id + pair.second;
auto weight_pop = compiler_resources_->tensors[weight_pd];
auto popart_var = pair.first + weight_pop;
auto paddle_var = weight_pd + pair.second;
if (scope_->FindVar(paddle_var_name) == nullptr) {
if (scope_->FindVar(paddle_var) == nullptr) {
continue;
}
if (!session_->hasInfo(popart_var_name)) {
if (!session_->hasInfo(popart_var)) {
continue;
}
auto var = scope_->GetVar(paddle_var_name);
VLOG(10) << "Connect paddle weight: " << paddle_var
<< " with popart weight: " << popart_var;
auto var = scope_->GetVar(paddle_var);
auto data_ptr = var->GetMutable<framework::LoDTensor>()->data();
auto tensor_info = session_->getInfo(popart_var_name);
executor_resources_->weights_io.insert(popart_var_name,
auto tensor_info = session_->getInfo(popart_var);
executor_resources_->weights_io.insert(popart_var,
{data_ptr, tensor_info});
executor_resources_->weights_and_opt_state.emplace_back(
std::make_pair(popart_var_name, paddle_var_name));
std::make_pair(popart_var, paddle_var));
}
}
}
......@@ -284,6 +295,7 @@ void Executor::ConvertWeights(bool align_to_popart) {
void Executor::WeightsFromPaddle() {
ConvertWeights(true);
session_->writeWeights(executor_resources_->weights_io);
session_->weightsFromHost();
}
// |-----------------------------------------------------|
......@@ -297,13 +309,13 @@ void Executor::WeightsFromPaddle() {
// Paddle -> halfToFloat: cast then save to paddle
// Popart -> Paddle: copy from paddle to popart
void Executor::WeightsToPaddle() {
session_->weightsToHost();
session_->readWeights(executor_resources_->weights_io);
ConvertWeights(false);
}
void Executor::SaveModelToHost(const std::string &path) {
if (session_) {
session_->weightsToHost();
WeightsToPaddle();
session_->modelToHost(path);
} else {
......
......@@ -20,6 +20,7 @@ limitations under the License. */
#include <popart/patterns/patterns.hpp>
#include <popart/session.hpp>
#include <popart/tensorinfo.hpp>
#include <popdist/popdist_poplar.hpp>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
......@@ -36,8 +37,7 @@ struct ExecutorResources {
// map<tensor_id, paddle_var_ptr>
popart::WeightsIO weights_io;
// <popart_var, paddle_var> pairs, include weights and optimizer states
std::vector<std::pair<popart::TensorId, popart::TensorId>>
weights_and_opt_state;
std::vector<std::pair<popart::TensorId, std::string>> weights_and_opt_state;
};
class Executor {
......@@ -53,14 +53,12 @@ class Executor {
const std::vector<Tensor *> &outputs,
const framework::ExecutionContext &ctx);
// sync weights from popart to paddle
void WeightsToHost();
// detach IPU
void Detach();
void SetWeightsIO();
void ConvertWeights(bool align_to_popart);
void WeightsFromPaddle();
void WeightsToPaddle();
// Scope
void SetScope(const Scope *scope) { scope_ = scope; }
......@@ -79,6 +77,10 @@ class Executor {
private:
void AcquireDevice();
void SetWeightsIO();
void ConvertWeights(bool);
void WeightsFromPaddle();
void WeightsToPaddle();
private:
// not own
......@@ -92,8 +94,6 @@ class Executor {
std::unique_ptr<popart::Session> session_;
// one OneSession means a graph
std::unique_ptr<ExecutorResources> executor_resources_;
int step_ = 0;
};
} // namespace ipu
......
......@@ -24,6 +24,8 @@ static constexpr const char *sIpuIndexAttr = "ipu_index";
static constexpr const char *sIpuStageAttr = "ipu_stage";
static constexpr const char *sMatmulSerializeFactor = "serialize_factor";
static constexpr const char *sMatmulSerializeMode = "serialize_mode";
static constexpr const char *sAvailMemAttribute = "__available_memory";
static constexpr const char *sOpNamescope = "op_namescope";
static constexpr const char *sOpIdentifyIdAttr = "op_identify_id";
static constexpr const char *sDebugInfoId = "__debug_info_id";
......
......@@ -62,23 +62,40 @@ IpuStrategy::IpuStrategy() {
[&]() { return name; })
ADD_BOOL_OPTION(is_training);
ADD_BOOL_OPTION(save_init_onnx);
ADD_BOOL_OPTION(save_onnx_checkpoint);
ADD_BOOL_OPTION(need_avg_shard);
ADD_BOOL_OPTION(enable_fp16);
ADD_BOOL_OPTION(transfer_cast_op);
ADD_BOOL_OPTION(use_no_bias_optimizer);
ADD_BOOL_OPTION(enable_distribution);
ADD_UINT64_OPTION(num_ipus);
ADD_UINT64_OPTION(batches_per_step);
ADD_UINT64_OPTION(micro_batch_size);
ADD_UINT64_OPTION(save_per_n_step);
ADD_UINT64_OPTION(random_seed);
ADD_DOUBLE_OPTION(available_memory_proportion);
ADD_DOUBLE_OPTION(loss_scaling);
ADD_DOUBLE_OPTION(max_weight_norm);
ADD_STRING_OPTION(accl1_type);
ADD_STRING_OPTION(accl2_type);
ADD_STRING_OPTION(accl3_type);
ADD_STRING_OPTION(onnx_dump_path);
ADD_STRING_OPTION(weight_decay_mode);
#undef ADD_STRING_OPTION
#undef ADD_DOUBLE_OPTION
#undef ADD_UINT64_OPTION
#undef ADD_BOOL_OPTION
#define ADD_RUNTIME_BOOL_OPTION(name, aliased_name) \
RegisterSetter(bool_options, #name, \
[&](bool value) { runtime_options.aliased_name = value; }); \
RegisterGetter(options_getter, options_type, #name, "bool", [&]() { \
return std::to_string(runtime_options.aliased_name); \
})
ADD_RUNTIME_BOOL_OPTION(runtime_options.enable_eval, enable_eval);
#undef ADD_RUNTIME_BOOL_OPTION
#define ADD_POPART_ENUM_OPTION_ALIAS(name, aliased_name, EnumType) \
RegisterSetter(uint64_options, #name, [&](std::uint64_t value) { \
PADDLE_ENFORCE_LT( \
......@@ -171,6 +188,7 @@ IpuStrategy::IpuStrategy() {
ADD_POPART_UINT64_OPTION_ALIAS(merge_var_update_mem_threshold,
mergeVarUpdateMemThreshold);
ADD_POPART_UINT64_OPTION_ALIAS(loose_threshold_at_peak, looseThresholdAtPeak);
ADD_POPART_UINT64_OPTION_ALIAS(replicated_graph_count, replicatedGraphCount);
ADD_POPART_UINT64_OPTION_ALIAS(accumulation_factor, accumulationFactor);
ADD_POPART_UINT64_OPTION_ALIAS(swap_limit_scheduler, swapLimitScheduler);
ADD_POPART_UINT64_OPTION_ALIAS(global_replication_factor,
......@@ -462,12 +480,30 @@ void IpuStrategy::SetTensorLocation(const std::string& tensor,
} else if (opt == "use_io_tiles_to_store") {
settings->location.storageTileSet =
value > 0 ? popart::TileSet::IO : popart::TileSet::Compute;
} else if (opt == "sharding_domain_with_all") {
settings->location.shardingDomain =
popart::CommGroup(popart::CommGroupType::All, value);
} else if (opt == "sharding_domain_with_consecutive") {
settings->location.shardingDomain =
popart::CommGroup(popart::CommGroupType::Consecutive, value);
} else if (opt == "sharding_domain_with_orthogonal") {
settings->location.shardingDomain =
popart::CommGroup(popart::CommGroupType::Orthogonal, value);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unknown option ' %s' for tensor location: %s", opt, tensor));
}
}
void IpuStrategy::SetAccumulateOuterFragmentSettings(
const std::uint64_t& schedule, const std::vector<int>& values) {
VLOG(10) << "SetAccumulateOuterFragmentSettings schedule:" << schedule;
auto schedule_ =
static_cast<popart::AccumulateOuterFragmentSchedule>(schedule);
popart_options.accumulateOuterFragmentSettings =
popart::AccumulateOuterFragmentSettings(schedule_, values);
}
void IpuStrategy::AddCustomOp(const std::string& paddle_op,
const std::string& popart_op,
const std::string& domain, int version) {
......
......@@ -34,15 +34,36 @@ Node *logical_not_handler(Graph *graph, Node *node) {
{GetOutputVarNode("Out", node)}, {});
}
Node *logical_or_handler(Graph *graph, Node *node) {
return CreateBaseOp(graph, node, "popart_logical_or",
{GetInputVarNode("X", node), GetInputVarNode("Y", node)},
{GetOutputVarNode("Out", node)}, {});
}
Node *logical_and_handler(Graph *graph, Node *node) {
return CreateBaseOp(graph, node, "popart_logical_and",
{GetInputVarNode("X", node), GetInputVarNode("Y", node)},
{GetOutputVarNode("Out", node)}, {});
}
Node *greater_than_handler(Graph *graph, Node *node) {
return CreateBaseOp(graph, node, "popart_greater",
{GetInputVarNode("X", node), GetInputVarNode("Y", node)},
{GetOutputVarNode("Out", node)}, {});
}
Node *less_than_handler(Graph *graph, Node *node) {
return CreateBaseOp(graph, node, "popart_less",
{GetInputVarNode("X", node), GetInputVarNode("Y", node)},
{GetOutputVarNode("Out", node)}, {});
}
REGISTER_HANDLER(equal, equal_handler);
REGISTER_HANDLER(logical_not, logical_not_handler);
REGISTER_HANDLER(logical_or, logical_or_handler);
REGISTER_HANDLER(logical_and, logical_and_handler);
REGISTER_HANDLER(greater_than, greater_than_handler);
REGISTER_HANDLER(less_than, less_than_handler);
} // namespace
} // namespace ipu
......
......@@ -98,6 +98,12 @@ Node *matmul_handler(Graph *graph, Node *node) {
if (x_rank == 1) {
perm = std::vector<int64_t>{0};
} else if (x_rank == 2) {
if (!transpose_x && !transpose_y && is_float_equal(alpha, 1.0f)) {
return CreateBaseOp(
graph, node, "popart_matmul",
{GetInputVarNode("X", node), GetInputVarNode("Y", node)},
node->outputs);
}
return CreateGemm(graph, node,
{GetInputVarNode("X", node), GetInputVarNode("Y", node)},
node->outputs, transpose_x, transpose_y, alpha);
......
此差异已折叠。
......@@ -129,7 +129,7 @@ class PredictExecutor : public MlirToRuntimeTranslator {
auto arg = predict_func.getArgument(i);
auto type = arg.getType();
// this param is TensorMap
if (type.isa<infrt::DenseTensorMapType>()) {
if (type.isa<infrt::DenseHostTensorMapType>()) {
auto* value = new host_context::Value(std::move(*map));
arguments_.push_back(value);
AddValue(predict_func.getArgument(i), value);
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册