提交 838ef366 编写于 作者: X xutianbing

add first paddle function example for ContextProjectionForward operator,

by going through Daoyuan's excellent paddle function design.
上级 54a2b1f6
......@@ -17,6 +17,10 @@ if(WITH_TESTING)
# file(GLOB test_files . *OpTest.cpp)
# add_executable(${test_bin} EXCLUDE_FROM_ALL ${test_files})
add_simple_unittest(CrossMapNormalOpTest)
add_unittest(ContextProjectionOpTest
ContextProjectionOpTest.cpp
ContextProjectionOpGpu.cu
../gserver/tests/TestUtil.cpp)
endif()
endif()
......
......@@ -30,6 +30,20 @@ real FuncConfig::get<real>(const std::string& key) const {
return it->second.r;
}
template <>
int FuncConfig::get<int>(const std::string& key) const {
auto it = valueMap_.find(key);
CHECK(it != valueMap_.end()) << "Cannot find value: '" << key << "'";
return it->second.i;
}
template <>
bool FuncConfig::get<bool>(const std::string& key) const {
auto it = valueMap_.find(key);
CHECK(it != valueMap_.end()) << "Cannot find value: '" << key << "'";
return it->second.b;
}
template <>
FuncConfig& FuncConfig::set<size_t>(const std::string& key, size_t v) {
CHECK_EQ(valueMap_.count(key), 0) << "Duplicated value: " << key;
......@@ -44,6 +58,20 @@ FuncConfig& FuncConfig::set<real>(const std::string& key, real v) {
return *this;
}
template <>
FuncConfig& FuncConfig::set<int>(const std::string& key, int v) {
CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key;
valueMap_[key].i = v;
return *this;
}
template <>
FuncConfig& FuncConfig::set<bool>(const std::string& key, bool v) {
CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key;
valueMap_[key].b = v;
return *this;
}
ClassRegistrar<FunctionBase> FunctionBase::funcRegistrar_;
} // namespace paddle
......@@ -59,6 +59,8 @@ public:
union value {
size_t s;
real r;
int i;
bool b;
};
template <typename T>
......
......@@ -33,25 +33,33 @@ public:
// init cpu and gpu arguments
auto initArgs = [=](
Arguments& cpuArgs, Arguments& gpuArgs, const Arguments& inArgs) {
for (auto arg : inArgs) {
for (const auto arg : inArgs) {
size_t size = sizeof(real);
for (auto dim : arg.dims_) {
for (const auto dim : arg.dims_) {
size *= dim;
}
cpuMemory.emplace_back(std::make_shared<CpuMemoryHandle>(size));
gpuMemory.emplace_back(std::make_shared<GpuMemoryHandle>(size));
cpuArgs.emplace_back(
Tensor((real*)cpuMemory.back()->getBuf(), arg.dims_));
gpuArgs.emplace_back(
Tensor((real*)gpuMemory.back()->getBuf(), arg.dims_));
// will use an api to refactor this code.
CpuVector cpuVector(size / sizeof(real),
(real*)cpuArgs.back().getData());
GpuVector gpuVector(size / sizeof(real),
(real*)gpuArgs.back().getData());
cpuVector.uniform(0.001, 1);
gpuVector.copyFrom(cpuVector);
if (arg.getData()) {
// todo(tianbing), waste unnecessary mem here
cpuMemory.emplace_back(std::make_shared<CpuMemoryHandle>(size));
gpuMemory.emplace_back(std::make_shared<GpuMemoryHandle>(size));
cpuArgs.emplace_back(Tensor((real*)arg.getData(), arg.dims_));
gpuArgs.emplace_back(Tensor((real*)arg.getData(), arg.dims_));
// already init outside
} else {
cpuMemory.emplace_back(std::make_shared<CpuMemoryHandle>(size));
gpuMemory.emplace_back(std::make_shared<GpuMemoryHandle>(size));
cpuArgs.emplace_back(
Tensor((real*)cpuMemory.back()->getBuf(), arg.dims_));
gpuArgs.emplace_back(
Tensor((real*)gpuMemory.back()->getBuf(), arg.dims_));
// will use an api to refactor this code.
CpuVector cpuVector(size / sizeof(real),
(real*)cpuArgs.back().getData());
GpuVector gpuVector(size / sizeof(real),
(real*)gpuArgs.back().getData());
cpuVector.uniform(0.001, 1);
gpuVector.copyFrom(cpuVector);
}
}
};
initArgs(cpuInputs, gpuInputs, inputs);
......@@ -81,6 +89,10 @@ public:
checkArgs(cpuInouts, gpuInouts);
}
std::shared_ptr<FunctionBase> getCpuFunction() const { return cpu; }
std::shared_ptr<FunctionBase> getGpuFunction() const { return gpu; }
protected:
std::shared_ptr<FunctionBase> cpu;
std::shared_ptr<FunctionBase> gpu;
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "context_projection_op.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/Vector.h"
namespace paddle {
template <>
void ContextProjectionForward<DEVICE_TYPE_CPU>(Tensor& output,
const Tensor& input,
const Tensor& weight,
const Tensor& sequence,
size_t context_length,
int context_start,
size_t begin_pad,
bool is_padding) {
CHECK(output.getData() && input.getData() && sequence.getData());
CHECK_EQ(output.dims_.size(), 2);
CHECK_EQ(input.dims_.size(), 2);
CHECK_EQ(weight.dims_.size(), 2);
CHECK_EQ(sequence.dims_.size(), 1);
auto out_mat = std::make_shared<CpuMatrix>(
output.getData(), output.dims_[0], output.dims_[1]);
const auto in_mat = std::make_shared<CpuMatrix>(
input.getData(), input.dims_[0], input.dims_[1]);
const auto weight_mat =
!weight.getData()
? nullptr
: std::make_shared<CpuMatrix>(
weight.getData(), weight.dims_[0], input.dims_[1]);
CpuIVector seq_vec(sequence.dims_[0],
reinterpret_cast<int*>(sequence.getData()));
CHECK_EQ(out_mat->getWidth(), in_mat->getWidth() * context_length);
const int* starts = seq_vec.getData();
const size_t num_sequences = seq_vec.getSize() - 1;
for (size_t i = 0; i < num_sequences; ++i) {
for (size_t j = 0; j < context_length; ++j) {
int begin = starts[i] + context_start + j;
int end = starts[i + 1] + context_start + j;
int dst_begin = starts[i];
int dst_end = starts[i + 1];
if (begin < starts[i]) {
int64_t pad_size =
std::min(starts[i] - begin, starts[i + 1] - starts[i]);
MatrixPtr mat = out_mat->subMatrix(starts[i], pad_size);
if (is_padding && weight_mat) {
MatrixPtr sub = weight_mat->subMatrix(j, pad_size);
mat->addAtOffset(*sub, j * in_mat->getWidth());
}
dst_begin = starts[i] + pad_size;
begin = starts[i];
}
if (end > starts[i + 1]) {
int64_t pad_size =
std::min(end - starts[i + 1], starts[i + 1] - starts[i]);
MatrixPtr mat = out_mat->subMatrix(starts[i + 1] - pad_size, pad_size);
if (is_padding && weight_mat) {
MatrixPtr sub = weight_mat->subMatrix(
begin_pad + context_start + j - pad_size, pad_size);
mat->addAtOffset(*sub, j * in_mat->getWidth());
}
dst_end = starts[i + 1] - pad_size;
end = starts[i + 1];
}
if (end <= begin) continue;
MatrixPtr src = in_mat->subMatrix(begin, end - begin);
MatrixPtr dst = out_mat->subMatrix(dst_begin, dst_end - dst_begin);
dst->addAtOffset(*src, j * in_mat->getWidth());
}
}
}
/**
* \param inputs[0] input value.
* \param inputs[1] input weight.
* \param inputs[2] input sequence.
* \param outputs[0] output value.
*/
template <DeviceType Device>
class ContextProjectionForwardFunc : public FunctionBase {
public:
void init(const FuncConfig& config) override {
context_length_ = config.get<size_t>("context_length");
context_start_ = config.get<int>("context_start");
begin_pad_ = config.get<size_t>("begin_pad");
is_padding_ = config.get<bool>("is_padding");
}
void calc(const Arguments& inputs,
const Arguments& outputs,
const Arguments& inouts) override {
CHECK_EQ(3, inputs.size());
CHECK_EQ(1, outputs.size());
CHECK_EQ(0, inouts.size());
ContextProjectionForward<Device>((Tensor&)outputs[0],
inputs[0],
inputs[1],
inputs[2],
context_length_,
context_start_,
begin_pad_,
is_padding_);
}
private:
size_t context_length_;
int context_start_;
size_t begin_pad_;
bool is_padding_;
};
REGISTER_TYPED_FUNC(ContextProjectionForward,
CPU,
ContextProjectionForwardFunc);
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(ContextProjectionForward,
GPU,
ContextProjectionForwardFunc);
#endif
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Function.h"
namespace paddle {
/**
* \brief Context Projection Forward.
*
* \param[out] outputs output data.
* \param[in] input input data.
* \param[in] weight input weight.
* \param[in] sequence input data.
* \param[in] context_length consecutive rows for concatenation.
* \param[in] begin_pad context start position.
* \param[in] is_padding whether padding 0 or not.
*
*/
template <DeviceType Device>
void ContextProjectionForward(Tensor& output,
const Tensor& input,
const Tensor& weight,
const Tensor& sequence,
size_t context_length,
int context_start,
size_t begin_pad,
bool is_padding);
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hl_base.h"
#include "context_projection_op.h"
namespace paddle {
template <bool padding>
__global__ void KeContextProjectionForward(const real* input,
const int* sequence,
const real* weight,
real* output,
int input_dim,
int context_length,
int context_start,
int begin_pad) {
int idx = threadIdx.x;
int block_size = blockDim.x;
int sequenceId = blockIdx.x;
int seq_start = sequence[sequenceId];
int seq_end = sequence[sequenceId+1];
real value = 0;
int instances = seq_end - seq_start + context_length - 1;
output += seq_start * input_dim * context_length;
input += seq_start * input_dim;
for (int k = 0; k <= input_dim / block_size; k++) {
if (idx < input_dim) {
for (int i = 0; i < instances; i++) {
// i + context_start;
if ((i + context_start) < 0) {
if (padding) {
value = weight[i * input_dim + idx];
} else {
continue;
}
} else if ((i + context_start) >= (seq_end - seq_start)) {
if (padding) {
value =
weight[(begin_pad + i + context_start - (seq_end - seq_start)) *
input_dim + idx];
} else {
continue;
}
} else {
value = input[(i + context_start) * input_dim + idx];
}
int outx = (i - context_length) < 0 ? i : (context_length - 1);
int outy = (i - context_length) < 0 ? 0 : (i - (context_length - 1));
real* output_r =
output + outy * input_dim * context_length + outx * input_dim;
for (int j = outy; j < seq_end - seq_start; j++) {
output_r[idx] += value;
if (j - outy == outx) break;
output_r += (context_length - 1) * input_dim;
}
}
}
idx += block_size;
}
}
void hl_context_projection_forward(const real* input,
const int* sequence,
real* weight,
real* output,
int num_sequences,
int input_dim,
int context_length,
int context_start,
int begin_pad,
bool is_padding) {
CHECK_NOTNULL(input);
CHECK_NOTNULL(sequence);
CHECK_NOTNULL(output);
CHECK(!is_padding || weight);
int block_size = 128;
int blocks_x = num_sequences;
int blocks_y = 1;
dim3 threads(block_size, 1);
dim3 grid(blocks_x, blocks_y);
if (is_padding) {
KeContextProjectionForward<true><<< grid, threads, 0, STREAM_DEFAULT >>>
(input, sequence, weight, output, input_dim,
context_length, context_start, begin_pad);
} else {
KeContextProjectionForward<false><<< grid, threads, 0, STREAM_DEFAULT >>>
(input, sequence, weight, output, input_dim,
context_length, context_start, begin_pad);
}
CHECK_SYNC("hl_context_projection_forward failed");
}
template <>
void ContextProjectionForward<DEVICE_TYPE_GPU>(Tensor& output,
const Tensor& input,
const Tensor& weight,
const Tensor& sequence,
size_t context_length,
int context_start,
size_t begin_pad,
bool is_padding) {
CHECK(output.getData() && input.getData() && sequence.getData());
CHECK_EQ(output.dims_.size(), 2);
CHECK_EQ(input.dims_.size(), 2);
CHECK_EQ(weight.dims_.size(), 2);
CHECK_EQ(sequence.dims_.size(), 1);
CHECK_EQ(output.dims_[1], input.dims_[1] * context_length);
hl_context_projection_forward(input.getData(),
reinterpret_cast<int*>(sequence.getData()),
weight.getData(),
output.getData(),
sequence.dims_[0] - 1,
input.dims_[1],
context_length,
context_start,
begin_pad,
is_padding);
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "FunctionTest.h"
#include "paddle/gserver/tests/TestUtil.h"
#include "paddle/math/Matrix.h"
using namespace paddle; // NOLINT
void testMatrixProjectionForward(int context_start,
size_t context_length,
bool is_padding,
size_t batch_size,
size_t input_dim) {
size_t pad = std::max(0, -context_start) +
std::max(0, (int)(context_start + context_length - 1));
if (pad == 0) is_padding = false;
FunctionCompare compare("ContextProjectionForward",
FuncConfig()
.set("context_length", context_length)
.set("context_start", context_start)
.set("begin_pad", std::max(0, -context_start))
.set("is_padding", is_padding));
CpuMatrix cpu_in(batch_size, input_dim);
cpu_in.randomizeUniform();
GpuMatrix gpu_in(batch_size, input_dim);
gpu_in.copyFrom(cpu_in);
auto cpu_weight =
is_padding ? std::make_shared<CpuMatrix>(pad, input_dim) : nullptr;
auto gpu_weight =
is_padding ? std::make_shared<GpuMatrix>(pad, input_dim) : nullptr;
if (is_padding) {
cpu_weight->randomizeUniform();
gpu_weight->copyFrom(*cpu_weight);
}
IVectorPtr cpu_seq;
generateSequenceStartPositions(batch_size, cpu_seq);
IVectorPtr gpu_seq = IVector::create(cpu_seq->getSize(), true);
gpu_seq->copyFrom(*cpu_seq);
CpuMatrix cpu_out(batch_size, input_dim * context_length);
GpuMatrix gpu_out(batch_size, input_dim * context_length);
cpu_out.randomizeUniform();
gpu_out.copyFrom(cpu_out);
compare.getCpuFunction()->calc(
{Tensor(cpu_in.getData(), Dims{batch_size, input_dim}),
Tensor(cpu_weight ? cpu_weight->getData() : nullptr,
Dims{pad, input_dim}),
Tensor(reinterpret_cast<real*>(cpu_seq->getData()),
Dims{cpu_seq->getSize()})},
{Tensor(cpu_out.getData(), Dims{batch_size, input_dim * context_length})},
{});
compare.getGpuFunction()->calc(
{Tensor(gpu_in.getData(), Dims{batch_size, input_dim}),
Tensor(gpu_weight ? gpu_weight->getData() : nullptr,
Dims{pad, input_dim}),
Tensor(reinterpret_cast<real*>(gpu_seq->getData()),
Dims{gpu_seq->getSize()})},
{Tensor(gpu_out.getData(), Dims{batch_size, input_dim * context_length})},
{});
autotest::TensorCheckEqual(cpu_out, gpu_out);
}
TEST(ContextProjectionForward, projection) {
for (auto context_start : {-5, -3, -1, 0, 3}) {
for (auto context_length : {1, 2, 5, 7}) {
for (auto trainable_padding : {false, true}) {
for (auto batch_size : {1, 2, 5, 20, 100}) {
for (auto input_dim : {15, 32, 63, 128, 200}) {
VLOG(3) << " context_start=" << context_start
<< " context_length=" << context_length
<< " trainable_padding=" << trainable_padding
<< " batch_size=" << batch_size
<< " input_dim=" << input_dim;
testMatrixProjectionForward(context_start,
context_length,
trainable_padding,
batch_size,
input_dim);
}
}
}
}
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册