提交 c2e851f7 编写于 作者: J JiabinYang

test=develop, remove sparse bias and add prefetch and related tests

上级 c35fdf15
......@@ -32,7 +32,7 @@ namespace paddle {
namespace operators {
namespace distributed {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;
......@@ -120,8 +120,8 @@ static void MergeMultipleVarsIntoOneBySection(
PADDLE_ENFORCE_GT(
out_tensor->numel(), 0,
"When calling this method, the Tensor's numel must larger than zero. "
"Please check Tensor::Resize has been called first.");
"When calling this method, the LoDTensor's numel must larger than zero. "
"Please check LoDTensor::Resize has been called first.");
auto* out_tensor_data = out_tensor->mutable_data<float>(id_tensor.place());
......@@ -144,7 +144,7 @@ static void MergeMultipleVarsIntoOneBySection(
auto row_numel = dims[1];
for (size_t i = 0; i < dims[0]; ++i) {
for (int64_t i = 0; i < dims[0]; ++i) {
auto id = ids_in_this_section[i];
auto origin_id = id + abs_sections[section_idx];
auto& offsets = id_to_offset[origin_id];
......@@ -201,7 +201,7 @@ void prefetch(const std::string& id_name, const std::string& out_name,
std::vector<int64_t> ids_vector;
if (platform::is_cpu_place(id_tensor.place())) {
auto* id_data = id_tensor.data<int64_t>();
for (size_t i = 0; i < id_tensor.numel(); ++i) {
for (int64_t i = 0; i < id_tensor.numel(); ++i) {
ids_vector.push_back(id_data[i]);
}
} else {
......@@ -209,7 +209,7 @@ void prefetch(const std::string& id_name, const std::string& out_name,
PADDLE_THROW("paddle is not compiled with CUDA!");
#else
auto cpu_place = platform::CPUPlace();
framework::Tensor cpu_tensor;
framework::LoDTensor cpu_tensor;
auto* cpu_tensor_data =
cpu_tensor.mutable_data<int64_t>(id_tensor.dims(), cpu_place);
auto stream =
......
......@@ -30,6 +30,30 @@ void prefetch(const std::string& id_name, const std::string& out_name,
const framework::ExecutionContext& context,
const framework::Scope& scope);
template <typename T>
void prefetch_with_reconstruct(const std::string& id_name,
const std::string& out_name,
const std::vector<std::string>& table_names,
const std::vector<std::string>& epmap,
const std::vector<int>& height_sections,
const framework::ExecutionContext& context,
const framework::Scope& scope,
framework::LoDTensor* original) {
prefetch(id_name, out_name, table_names, epmap, height_sections, context,
scope);
auto& out = scope.FindVar(out_name)->Get<framework::LoDTensor>();
auto& ids = scope.FindVar(id_name)->Get<framework::LoDTensor>();
auto* original_value = original->data<T>();
auto* out_value = out.data<T>();
size_t original_width = original->numel() / original->dims()[0];
for (int64_t i = 0; i < ids.numel(); i++) {
const T* out_rows = out_value + original_width * i;
T* original_row = original_value + original_width * ids.data<int64_t>()[i];
std::memcpy(original_row, out_rows, original_width * sizeof(T));
}
}
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
......@@ -67,6 +67,11 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("PreOut"),
"Output(PreOut) should not be null.");
auto with_prefetch = ctx->Attrs().Get<bool>("remote_prefetch");
if (with_prefetch) {
PADDLE_ENFORCE(ctx->HasOutput("W_Out"),
"Output(W_Out) should not be null.");
}
const int64_t batch_size = ctx->GetInputDim("X")[0];
std::vector<int64_t> output_shape({batch_size, 1});
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
......@@ -96,7 +101,7 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Label",
"(LoDTensor, required), The labels of training data. It's a"
"tensor with shape [N, 1].");
AddInput("PTable",
AddInput("PathTable",
"(LoDTensor, optional), The Path Table from root to current word"
"it should have shape like [N, L], L is the length of the Path")
.AsDispensable();
......@@ -120,8 +125,30 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
"[batch_size, code_length], where code_length represents the "
"maximum path length from root to leaf nodes.")
.AsIntermediate();
AddOutput(
"W_Out",
"(LoDTensor, optinal) using input 'W' as Output to make it mutable"
"When we are using prefetch")
.AsIntermediate();
AddAttr<AttrType>("num_classes", "(int, optional), The number of classes")
.SetDefault(2);
// for parameter prefetch
AddAttr<bool>("remote_prefetch", "").SetDefault(false);
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::vector<int>>("height_sections",
"Height for each output SelectedRows.")
.SetDefault(std::vector<int>({}));
AddAttr<std::vector<std::string>>(
"epmap",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints in the order of input variables for mapping")
.SetDefault({});
AddAttr<std::vector<std::string>>(
"table_names",
"(string vector, the splited table names that will be fetched from "
"parameter server)"
"in the order of input variables for mapping")
.SetDefault({});
AddComment(R"DOC(
The hierarchical sigmoid operator organize the classes into a binary tree.
At each node, a sigmoid function is used to calculate the probability of
......@@ -191,23 +218,17 @@ class HierarchicalSigmoidGradOpGradVarTypeInference
<< " is set to SelectedRows";
block->Var(w_grad_var_name)
->SetType(framework::proto::VarType::SELECTED_ROWS);
if (hasBias) {
VLOG(30) << "hierarchical_sigmoid_grad op "
<< framework::GradVarName("Bias") << " is set to SelectedRows";
block->Var(bias_grad_var_name)
->SetType(framework::proto::VarType::SELECTED_ROWS);
}
} else {
VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
<< " is set to LoDTensor";
block->Var(w_grad_var_name)
->SetType(framework::proto::VarType::LOD_TENSOR);
if (hasBias) {
VLOG(30) << "hierarchical_sigmoid_grad op "
<< framework::GradVarName("Bias") << " is set to LoDTensor";
block->Var(bias_grad_var_name)
->SetType(framework::proto::VarType::LOD_TENSOR);
}
}
if (hasBias) {
VLOG(30) << "hierarchical_sigmoid_grad op "
<< framework::GradVarName("Bias") << " is set to LoDTensor";
block->Var(bias_grad_var_name)
->SetType(framework::proto::VarType::LOD_TENSOR);
}
block->Var(w_grad_var_name)->SetDataType(block->Var("W")->GetDataType());
}
......
......@@ -14,7 +14,9 @@ limitations under the License. */
#pragma once
#include <iostream>
#include <iterator>
#include <set>
#include <string>
#include <vector>
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -24,6 +26,10 @@ limitations under the License. */
#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include "paddle/fluid/platform/transform.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
#endif
namespace paddle {
namespace operators {
......@@ -49,13 +55,55 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto& in = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
auto& w = detail::Ref(ctx.Input<framework::LoDTensor>("W"));
auto* path = ctx.Input<framework::LoDTensor>("PTable");
auto* path = ctx.Input<framework::LoDTensor>("PathTable");
auto* code = ctx.Input<framework::LoDTensor>("PathCode");
auto& label = detail::Ref(ctx.Input<framework::LoDTensor>("Label"));
auto* bias = ctx.Input<framework::LoDTensor>("Bias");
auto* out = ctx.Output<framework::LoDTensor>("Out");
auto* pre_out = ctx.Output<framework::LoDTensor>("PreOut");
size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
// for remote prefetch
auto epmap = ctx.Attr<std::vector<std::string>>("epmap");
if (!epmap.empty()) {
// if epmap is not empty, then the parameter will be fetched from remote
// parameter
// server
auto height_sections = ctx.Attr<std::vector<int>>("height_sections");
auto table_names = ctx.Attr<std::vector<std::string>>("table_names");
VLOG(3) << "path type is " << path->type().name();
std::vector<int64_t> real_rows = PathToRows(*path);
framework::Scope& local_scope = ctx.scope().NewScope();
auto* ids = local_scope.Var("Ids@Prefetch");
auto* x_tensor = ids->GetMutable<framework::LoDTensor>();
x_tensor->mutable_data<int64_t>(
framework::make_ddim({static_cast<int64_t>(real_rows.size()), 1}),
ctx.GetPlace());
// copy.
std::memcpy(x_tensor->data<int64_t>(), real_rows.data(),
real_rows.size() * sizeof(int64_t));
framework::DDim w_dims = ctx.Input<Tensor>("W")->dims();
w_dims[0] = x_tensor->dims()[0];
auto* w_tensor =
local_scope.Var("W@Prefetch")->GetMutable<framework::LoDTensor>();
w_tensor->Resize(w_dims);
#ifdef PADDLE_WITH_DISTRIBUTE
// w_Out is set to used by prefetch, never change it in other cases
auto* w_out = ctx.Output<framework::LoDTensor>("W_Out");
operators::distributed::prefetch_with_reconstruct<T>(
"Ids@Prefetch", "W@Prefetch", table_names, epmap, height_sections,
ctx, local_scope, w_out);
#else
PADDLE_THROW(
"paddle is not compiled with distribute support, can not do "
"parameter prefetch!");
#endif
}
bool is_custom = false;
if (path) {
is_custom = true;
......@@ -116,9 +164,8 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto& in = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
auto& w = detail::Ref(ctx.Input<framework::LoDTensor>("W"));
auto* path = ctx.Input<framework::LoDTensor>("PTable");
auto* path = ctx.Input<framework::LoDTensor>("PathTable");
auto* code = ctx.Input<framework::LoDTensor>("PathCode");
auto* bias = ctx.Input<framework::LoDTensor>("Bias");
auto* in_grad =
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
bool is_sparse = ctx.Attr<bool>("is_sparse");
......@@ -165,15 +212,14 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
pre_out_grad_mat * out_grad_mat.broadcast(bcast);
// TODO(guosheng): multiply pre_out_grad with subgradient of clipping to
// be consistent with the clipping in forward.
auto* bias_grad =
ctx.Output<framework::LoDTensor>(framework::GradVarName("Bias"));
if (bias_grad) {
bias_grad->mutable_data<T>(ctx.GetPlace());
zero(dev_ctx, bias_grad, static_cast<T>(0.0));
bit_code->AddGrad(pre_out_grad, bias_grad);
}
if (!is_sparse) {
auto* bias_grad =
ctx.Output<framework::LoDTensor>(framework::GradVarName("Bias"));
if (bias_grad) {
bias_grad->mutable_data<T>(ctx.GetPlace());
zero(dev_ctx, bias_grad, static_cast<T>(0.0));
bit_code->AddGrad(pre_out_grad, bias_grad);
}
auto* w_grad =
ctx.Output<framework::LoDTensor>(framework::GradVarName("W"));
w_grad->mutable_data<T>(ctx.GetPlace());
......@@ -192,21 +238,6 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
w_grad_value->mutable_data<T>(temp_dim, ctx.GetPlace());
zero(dev_ctx, w_grad_value, static_cast<T>(0.0));
auto* bias_grad =
ctx.Output<framework::SelectedRows>(framework::GradVarName("Bias"));
if (bias_grad) {
bias_grad->set_rows(real_rows);
// build ids -> rows index map
bias_grad->SyncIndex();
bias_grad->set_height(bias->dims()[0]);
auto* bias_grad_value = bias_grad->mutable_value();
std::vector<int64_t> dims = {static_cast<int64_t>(real_rows.size()),
bias->dims()[1]};
bias_grad_value->mutable_data<T>(framework::make_ddim(dims),
ctx.GetPlace());
zero(dev_ctx, bias_grad_value, static_cast<T>(0.0));
bit_code->AddGrad(pre_out_grad, bias_grad);
}
bit_code->MulGradWeight(pre_out_grad, w_grad, in);
}
bit_code->MulGradError(pre_out_grad, w, in_grad);
......
......@@ -48,23 +48,6 @@ void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat,
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat,
framework::SelectedRows* vec) {
size_t batch_size = tmat.dims()[0];
size_t width = tmat.dims()[1];
for (size_t i = 0; i < batch_size; ++i) {
auto code = code_table_->get_code(i);
int code_length = code->get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code->calc_index(j);
int64_t row_index = vec->GetIndexFromId(static_cast<int64_t>(index));
vec->mutable_value()->data<T>()[row_index] +=
tmat.data<T>()[i * width + j];
}
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor& tmat,
framework::Tensor* sum, T scale_sum) {
......
......@@ -139,11 +139,11 @@ class SimpleCode : public Code {
template <typename T>
class CustomCode : public Code {
public:
CustomCode(const framework::Tensor& ptable, const framework::Tensor& pcode,
const int64_t* ids, int index)
CustomCode(const framework::Tensor& path_table,
const framework::Tensor& path_code, const int64_t* ids, int index)
: ids_(ids), index_(index) {
ptable_ = ptable.Slice(index, index + 1);
pcode_ = pcode.Slice(index, index + 1);
ptable_ = path_table.Slice(index, index + 1);
pcode_ = path_code.Slice(index, index + 1);
}
/**
* Here the id of root shoud be 1 rather than 0, thus the encoding of class c
......@@ -195,9 +195,9 @@ class SimpleCodeTable : public CodeTable {
template <typename T>
class CustomCodeTable : public CodeTable {
public:
CustomCodeTable(const framework::Tensor& ptable,
const framework::Tensor& pcode, const int64_t* ids)
: ptable_(ptable), pcode_(pcode), ids_(ids) {}
CustomCodeTable(const framework::Tensor& path_table,
const framework::Tensor& path_code, const int64_t* ids)
: ptable_(path_table), pcode_(path_code), ids_(ids) {}
std::unique_ptr<Code> get_code(int64_t code) const {
std::unique_ptr<Code> coder(new CustomCode<T>(ptable_, pcode_, ids_, code));
......@@ -223,11 +223,11 @@ class MatrixBitCodeFunctor {
ids_(ids),
code_table_(new SimpleCodeTable(num_classes, ids)) {}
MatrixBitCodeFunctor(const framework::Tensor& ptable,
const framework::Tensor& pcode, const int64_t* ids)
: num_classes_(static_cast<size_t>(ptable.dims()[1])),
MatrixBitCodeFunctor(const framework::Tensor& path_table,
const framework::Tensor& path_code, const int64_t* ids)
: num_classes_(static_cast<size_t>(path_table.dims()[1])),
ids_(ids),
code_table_(new CustomCodeTable<int64_t>(ptable, pcode, ids)) {}
code_table_(new CustomCodeTable<int64_t>(path_table, path_code, ids)) {}
/* For j < code_length
tmat(i, j) += vec(0, index(i, j))
*/
......@@ -238,11 +238,6 @@ class MatrixBitCodeFunctor {
*/
void AddGrad(const framework::Tensor& tmat, framework::Tensor* vec);
/* For selected rows For j < code_length
vec(0, index(i, j)) += tmat(i, j)
*/
void AddGrad(const framework::Tensor& tmat, framework::SelectedRows* vec);
/* For j < code_length
sum(i, 0) = \sum_j bit(i, j) * tmat(i, j)
*/
......
......@@ -4931,6 +4931,9 @@ def hsigmoid(input,
pass
weights = None
remote_prefetch = False
if os.environ.get('PADDLE_ENABLE_REMOTE_PREFETCH'):
remote_prefetch = True
if not is_custom:
weights = helper.create_parameter(
......@@ -4947,7 +4950,7 @@ def hsigmoid(input,
inputs = {
"X": input,
"W": weights,
"PTable": path_table,
"PathTable": path_table,
"PathCode": path_code,
"Label": label
}
......@@ -4970,9 +4973,13 @@ def hsigmoid(input,
type="hierarchical_sigmoid",
inputs=inputs,
outputs={"Out": out,
"PreOut": pre_out},
attrs={"num_classes": num_classes,
"is_sparse": is_sparse})
"PreOut": pre_out,
"W_Out": weights},
attrs={
"num_classes": num_classes,
"is_sparse": is_sparse,
"remote_prefetch": remote_prefetch
})
return out
......@@ -7440,7 +7447,7 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None):
Examples:
.. code-block:: python
.. code-block:: python
x = fluid.layers.data(name="x", shape=[2,3,16,16], dtype="float32")
y = fluid.layers.brelu(x, t_min=1.0, t_max=20.0)
......
......@@ -185,7 +185,7 @@ class TestHSigmoidOpSparse(OpTest):
self.inputs = {
'X': x,
'W': w,
'PTable': path_table,
'PathTable': path_table,
'PathCode': path_code,
'Label': label,
'Bias': bias
......@@ -287,7 +287,7 @@ class TestHSigmoidOpWithCostumTree(OpTest):
self.inputs = {
'X': x,
'W': w,
'PTable': path_table,
'PathTable': path_table,
'PathCode': path_code,
'Label': label,
'Bias': bias
......@@ -324,7 +324,7 @@ class TestHSigmoidOpWithCostumTreeWithoutBias(OpTest):
self.inputs = {
'X': x,
'W': w,
'PTable': path_table,
'PathTable': path_table,
'PathCode': path_code,
'Label': label,
}
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import signal
import time
import unittest
from multiprocessing import Process
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from paddle.fluid.framework import Program, program_guard
def run_pserver(pserver_id, use_cuda, sync_mode):
scope = fluid.core.Scope()
program = Program()
with fluid.scope_guard(scope):
with program_guard(program, startup_program=Program()):
# create table parameter in scope
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
# create and initialize Param Variable
param = scope.var('table').get_tensor()
param_array = np.ones((5, 8)).astype("float32")
for i in range(len(param_array)):
param_array[i] *= param_array[i] * i + pserver_id * 10 + 1
param.set(param_array, place)
optimize_block = program._create_block(program.global_block().idx)
program.global_block().append_op(
type="listen_and_serv",
inputs={'X': []},
outputs={},
attrs={
"optimize_blocks": [optimize_block],
"endpoint": '127.0.0.1:0',
"Fanin": 1,
"sync_mode": True,
"grad_to_block_id": []
})
exe = fluid.Executor(place)
exe.run(program)
class TestListenAndServOp(unittest.TestCase):
def setUp(self):
self.ps_timeout = 5
def _start_pserver(self, pserver_id, use_cuda, sync_mode, pserver_func):
p = Process(target=pserver_func, args=(pserver_id, use_cuda, sync_mode))
p.daemon = True
p.start()
return p
def _wait_ps_ready(self, pid):
start_left_time = self.ps_timeout
sleep_time = 0.5
while True:
assert start_left_time >= 0, "wait ps ready failed"
time.sleep(sleep_time)
try:
# the listen_and_serv_op would touch a file which contains the listen port
# on the /tmp directory until it was ready to process all the RPC call.
os.stat("/tmp/paddle.%d.port" % pid)
return
except os.error:
start_left_time -= sleep_time
def _get_pserver_port(self, pid):
with open("/tmp/paddle.%d.port" % pid, 'r') as f:
port = int(f.read().strip())
return port
def _run_hsigmoid_op_one_pserver(self, place, port):
scope = fluid.core.Scope()
program = Program()
with fluid.scope_guard(scope):
with program_guard(program, startup_program=Program()):
x = scope.var('X').get_tensor()
x_array = np.random.random((4, 8)).astype("float32") * 2
x.set(x_array, place)
# create and initialize Param Variable
param = scope.var('W').get_tensor()
param_array = np.zeros((5, 8)).astype("float32") * 2
param.set(param_array, place)
path_table = scope.var('PathTable').get_tensor()
path_table_array = np.array(
[(0, 2, -1, -1, -1), (0, 1, 2, -1, -1), (0, 1, 4, -1, -1),
(0, 2, -1, -1, -1)]).astype(
"int64"
) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf)
path_table.set(path_table_array, place)
path_code = scope.var('PathCode').get_tensor()
path_code_array = np.array(
[(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), (1, 0, 0, -1, -1),
(0, 1, -1, -1, -1)]).astype("int64") #np.array to store
path_code.set(path_code_array, place)
label = scope.var('Label').get_tensor()
label_array = np.array([0, 1, 4, 5])
label.set(label_array, place)
bias = scope.var('Bias').get_tensor()
bias_array = np.random.random((5, 1)).astype("float32")
bias.set(bias_array, place)
out = scope.var('Out').get_tensor()
pre_out = scope.var('PreOut').get_tensor
w_out = scope.var('W_Out').get_tensor()
w_out.set(param_array, place)
emaps = ['127.0.0.1:' + str(port)]
table_names = ['table']
height_sections = [2]
# create and run sgd operator
hsigmoid_op = Operator(
"hierarchical_sigmoid",
X='X',
W='W',
PathTable='PathTable',
PathCode='PathCode',
Label='Label',
Bias='Bias',
Out='Out',
PreOut='PreOut',
W_Out='W_Out',
remote_prefetch=True,
epmap=emaps,
table_names=table_names,
height_sections=height_sections)
hsigmoid_op.run(scope, place)
# get and compare result
result_array = np.array(w_out)
self.assertEqual(list(result_array.shape), [5, 8])
correct = None
for i in range(5):
if i != 3:
correct = np.full((1, 8), i + 1).astype("float32")
self.assertTrue((result_array[i] == correct).all())
else:
correct = np.full((1, 8), 0).astype("float32")
self.assertTrue((result_array[i] == correct).all())
def _run_hsigmoid_op_two_pserver(self, place, port0, port1):
scope = fluid.core.Scope()
program = Program()
with fluid.scope_guard(scope):
with program_guard(program, startup_program=Program()):
x = scope.var('X').get_tensor()
x_array = np.random.random((4, 8)).astype("float32") * 2
x.set(x_array, place)
# create and initialize Param Variable
param = scope.var('W').get_tensor()
param_array = np.zeros((5, 8)).astype("float32") * 2
param.set(param_array, place)
path_table = scope.var('PathTable').get_tensor()
path_table_array = np.array(
[(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1),
(0, 2, -1, -1, -1)]).astype(
"int64"
) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf)
path_table.set(path_table_array, place)
path_code = scope.var('PathCode').get_tensor()
path_code_array = np.array(
[(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), (1, 0, 0, -1, -1),
(0, 1, -1, -1, -1)]).astype("int64") #np.array to store
path_code.set(path_code_array, place)
label = scope.var('Label').get_tensor()
label_array = np.array([0, 1, 4, 5])
label.set(label_array, place)
bias = scope.var('Bias').get_tensor()
bias_array = np.random.random((5, 1)).astype("float32")
bias.set(bias_array, place)
out = scope.var('Out').get_tensor()
pre_out = scope.var('PreOut').get_tensor
w_out = scope.var('W_Out').get_tensor()
w_out.set(param_array, place)
emaps = ['127.0.0.1:' + str(port0), '127.0.0.1:' + str(port1)]
table_names = ['table', 'table']
height_sections = [2, 3]
# create and run sgd operator
hsigmoid_op = Operator(
"hierarchical_sigmoid",
X='X',
W='W',
PathTable='PathTable',
PathCode='PathCode',
Label='Label',
Bias='Bias',
Out='Out',
PreOut='PreOut',
W_Out='W_Out',
remote_prefetch=True,
epmap=emaps,
table_names=table_names,
height_sections=height_sections)
hsigmoid_op.run(scope, place)
# get and compare result
result_array = np.array(w_out)
self.assertEqual(list(result_array.shape), [5, 8])
correct = None
for i in range(5):
if i < 2:
correct = np.full((1, 8), i + 1).astype("float32")
self.assertTrue((result_array[i] == correct).all())
else:
correct = np.full((1, 8), i + 9).astype("float32")
self.assertTrue((result_array[i] == correct).all())
def test_hsigmoid_op_remote(self):
os.environ['PADDLE_ENABLE_REMOTE_PREFETCH'] = "1"
# run pserver on CPU in sync mode
p0 = self._start_pserver(0, False, True, run_pserver)
self._wait_ps_ready(p0.pid)
port0 = self._get_pserver_port(p0.pid)
p1 = self._start_pserver(1, False, True, run_pserver)
self._wait_ps_ready(p1.pid)
port1 = self._get_pserver_port(p1.pid)
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
self._run_hsigmoid_op_one_pserver(place, port0)
self._run_hsigmoid_op_two_pserver(place, port0, port1)
# raise SIGTERM to pserver
os.kill(p0.pid, signal.SIGINT)
p0.join()
os.kill(p1.pid, signal.SIGINT)
p1.join()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册