提交 cfad83ce 编写于 作者: Y yangyaming

Add MulValueLayer.

上级 6f43c936
......@@ -45,6 +45,7 @@ if(WITH_GPU)
add_simple_unittest(BlockExpandOpTest)
add_simple_unittest(CropOpTest)
add_simple_unittest(SwitchOpTest)
add_simple_unittest(MulValueOpTest)
endif()
add_simple_unittest(Im2ColTest)
......
......@@ -110,6 +110,7 @@ public:
function2_(FunctionBase::funcRegistrar_.createByType(name2)) {
function1_->init(config);
function2_->init(config);
initArgsCallBack_ = nullptr;
}
~Compare2Function() {}
......@@ -170,6 +171,10 @@ public:
*seq2_));
}
void registerInitCallBack(std::function<void(BufferArg&, size_t)> callback) {
initArgsCallBack_ = callback;
}
// output need only contains shape, do not contains data.
void addOutputs(const BufferArg& output, ArgType argType = ASSIGN_TO) {
size_t size =
......@@ -340,6 +345,10 @@ protected:
initArg(*func1Inputs_[i]);
}
if (initArgsCallBack_ != nullptr) {
initArgsCallBack_(*func1Inputs_[i], i);
}
copyArg_(*func1Inputs_[i], *func2Inputs_[i]);
}
}
......@@ -386,6 +395,7 @@ protected:
std::shared_ptr<SequenceIdArg> seq1_;
std::shared_ptr<SequenceIdArg> seq2_;
test::CopyArgument<DType1, DType2> copyArg_;
std::function<void(BufferArg&, size_t)> initArgsCallBack_;
};
class CpuGpuFuncCompare
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "MulValueOp.h"
#include "paddle/function/TensorShape.h"
namespace paddle {
template <>
void MulValue<DEVICE_TYPE_CPU>(real* outputs,
const real* inputs,
const real* indices,
const TensorShape shape,
const FuncConfig& conf) {
real value = conf.get<real>("value");
int number = shape[0];
int channel = shape[1];
int height = shape[2];
int width = shape[3];
memcpy(outputs, inputs, number * channel * height * width * sizeof(real));
for (int n = 0; n < number; ++n) {
// indices start from 1
int offset = n * 6;
for (int c = indices[offset] - 1; c < indices[offset + 1]; ++c) {
for (int h = indices[offset + 2] - 1; h < indices[offset + 3]; ++h) {
for (int w = indices[offset + 4] - 1; w < indices[offset + 5]; ++w) {
int idx = ((n * channel + c) * height + h) * width + w;
outputs[idx] *= value;
}
}
}
}
}
template <>
void MulValueGrad<DEVICE_TYPE_CPU>(const real* inGrad,
real* outGrad,
const real* indices,
const TensorShape shape,
const FuncConfig& conf) {
real value = conf.get<real>("value");
int number = shape[0];
int channel = shape[1];
int height = shape[2];
int width = shape[3];
for (int n = 0; n < number; ++n) {
for (int c = 0; c < channel; ++c) {
for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) {
int idx = ((n * channel + c) * height + h) * width + w;
int offset = n * 6;
if (c >= (indices[offset] - 1) && c <= (indices[offset + 1] - 1) &&
h >= (indices[offset + 2] - 1) &&
h <= (indices[offset + 3] - 1) &&
w >= (indices[offset + 4] - 1) &&
w <= (indices[offset + 5] - 1)) {
outGrad[idx] += inGrad[idx] * value;
} else {
outGrad[idx] += inGrad[idx];
}
}
}
}
}
}
/**
* \brief For each instance, MulValue can be used to multiply a value to a
* specified sub continuous region. By providing start index and end
* index for C/H/W, you can specify the location and shape of the region.
*
* Argument in this Function:
* \param inputs A 4-D tensor with shape [N, C, H, W], only one input.
* \param indices A 2-D tensor with shape [N, 6], indicates the sub region.
* \param outputs A 4-D tensor with same shape as inputs, output value.
*/
template <DeviceType Device>
class MulValueFunc : public FunctionBase {
public:
void init(const FuncConfig& config) override { conf_ = config; }
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(2UL, inputs.size());
CHECK_EQ(1UL, outputs.size());
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
TensorShape shape = inputs[0].shape();
MulValue<Device>(outputs[0].data<real>(),
inputs[0].data<real>(),
inputs[1].data<real>(),
shape,
conf_);
}
private:
FuncConfig conf_;
};
/**
* \brief The backward propagation of MulValue Function.
*
* Argument in this Function:
* \param inputs A 4-D tensor with shape [N, C, H, W], output gradient.
* \param indices A 2-D tensor with shape [N, 6], indicates the sub region.
* \param outputs A 4-D tensor with shape [N, C, H, W], gradient of input value.
*/
template <DeviceType Device>
class MulValueGradFunc : public FunctionBase {
public:
void init(const FuncConfig& config) override { conf_ = config; }
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(2UL, inputs.size());
CHECK_EQ(1UL, outputs.size());
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
TensorShape shape = inputs[0].shape();
MulValueGrad<Device>(inputs[0].data<real>(),
outputs[0].data<real>(),
inputs[1].data<real>(),
shape,
conf_);
}
private:
FuncConfig conf_;
};
REGISTER_TYPED_FUNC(MulValue, CPU, MulValueFunc);
REGISTER_TYPED_FUNC(MulValueGrad, CPU, MulValueGradFunc);
#ifdef PADDLE_WITH_CUDA
REGISTER_TYPED_FUNC(MulValue, GPU, MulValueFunc);
REGISTER_TYPED_FUNC(MulValueGrad, GPU, MulValueGradFunc);
#endif
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Function.h"
namespace paddle {
/**
* \brief Function to multiply a value to values in specified sub continuous
* region. Indices must be provided to indcate the location and shape of
* the region and the multiplied value is passed by configure variable.
*
*
* \param[out] outputs Output value.
* \param[in] inputs Input data which contains NCHW information.
* \param[in] indices Indices data to indcate the sub region.
* \param[in] shape Tensor shape of input value.
* \param[in] conf Configure variable which contains the multiplied value.
*/
template <DeviceType Device>
void MulValue(real* outputs,
const real* inputs,
const real* indices,
const TensorShape shape,
const FuncConfig& conf);
/**
* \brief Back propagation function of MulValue.
*
* \param[out] inGrad Gradients of previous layer.
* \param[in] outGrad Output gradient.
* \param[in] indices Indices data.
* \param[in] shape The Shape of input tensor.
* \param[in] conf Configure variable.
*/
template <DeviceType Device>
void MulValueGrad(const real* inGrad,
real* outGrad,
const real* indices,
const TensorShape shape,
const FuncConfig& conf);
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "MulValueOp.h"
#include "hl_base.h"
namespace paddle {
__global__ void KeMulValue(real* outputs,
const real* inputs,
const real* indices,
real value,
int channel,
int height,
int width,
int nthreads) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < nthreads) {
const int w = idx % width;
const int h = (idx / width) % height;
const int c = (idx / width / height) % channel;
const int n = idx / width / height / channel;
const int offset = n * 6;
if (c >= (indices[offset] - 1) && c <= (indices[offset + 1] - 1) &&
h >= (indices[offset + 2] - 1) && h <= (indices[offset + 3] - 1) &&
w >= (indices[offset + 4] - 1) && w <= (indices[offset + 5] - 1)) {
outputs[idx] = inputs[idx] * value;
} else {
outputs[idx] = inputs[idx];
}
}
}
template <>
void MulValue<DEVICE_TYPE_GPU>(real* outputs,
const real* inputs,
const real* indices,
const TensorShape shape,
const FuncConfig& conf) {
real value = conf.get<real>("value");
int number = shape[0];
int channel = shape[1];
int height = shape[2];
int width = shape[3];
size_t nth = number * channel * height * width;
int blockSize = 1024;
int gridSize = (nth + blockSize - 1) / blockSize;
KeMulValue<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(
outputs, inputs, indices, value, channel, height, width, nth);
CHECK_SYNC("MulValue");
}
__global__ void KeMulValueDiff(const real* inGrad,
real* outGrad,
const real* indices,
real value,
int channel,
int height,
int width,
int nthreads) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < nthreads) {
const int w = idx % width;
const int h = (idx / width) % height;
const int c = (idx / width / height) % channel;
const int n = idx / width / height / channel;
const int offset = n * 6;
if (c >= (indices[offset] - 1) && c <= (indices[offset + 1] - 1) &&
h >= (indices[offset + 2] - 1) && h <= (indices[offset + 3] - 1) &&
w >= (indices[offset + 4] - 1) && w <= (indices[offset + 5] - 1)) {
outGrad[idx] += inGrad[idx] * value;
} else {
outGrad[idx] += inGrad[idx];
}
}
}
template <>
void MulValueGrad<DEVICE_TYPE_GPU>(const real* inGrad,
real* outGrad,
const real* indices,
const TensorShape shape,
const FuncConfig& conf) {
real value = conf.get<real>("value");
int number = shape[0];
int channel = shape[1];
int height = shape[2];
int width = shape[3];
size_t nth = number * channel * height * width;
int blockSize = 1024;
int gridSize = (nth + blockSize - 1) / blockSize;
KeMulValueDiff<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(
inGrad, outGrad, indices, value, channel, height, width, nth);
CHECK_SYNC("MulValueGrad");
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "FunctionTest.h"
namespace paddle {
/*
for (size_t numSamples : {5, 32}) {
for (size_t channels : {5, 5, 32}) {
for (size_t imgSizeH : {5, 33, 100}) {
for (size_t imgSizeW : {5, 32, 96}) {
for (real value : {-0.5, 0.0, 0.5}) {
*/
TEST(MulValue, real) {
for (size_t numSamples : {5, 32}) {
for (size_t channels : {5, 5, 32}) {
for (size_t imgSizeH : {5, 33, 100}) {
for (size_t imgSizeW : {5, 32, 96}) {
for (real value : {-0.5, 0.0, 0.5}) {
for (bool firstHalf : {false, true}) {
VLOG(3) << " numSamples=" << numSamples
<< " channels=" << channels << " imgSizeH=" << imgSizeH
<< " imgSizeW=" << imgSizeW;
for (bool test_grad : {false}) {
CpuGpuFuncCompare compare(
test_grad ? "MulValueGrad" : "MulValue",
FuncConfig().set<real>("value", value));
TensorShape shape{numSamples, channels, imgSizeH, imgSizeW};
TensorShape indicesShape{numSamples, 6};
compare.addInputs(BufferArg(VALUE_TYPE_FLOAT, shape));
compare.addInputs(BufferArg(VALUE_TYPE_FLOAT, indicesShape));
compare.registerInitCallBack([=](BufferArg& arg, size_t index) {
if (index == 1) {
real* data = (real*)arg.data();
for (size_t i = 0; i < numSamples; ++i) {
size_t offset = i * 6;
data[offset] = firstHalf ? 1 : (int)channels / 2;
data[offset + 1] =
firstHalf ? (int)channels / 2 : channels;
data[offset + 2] = firstHalf ? 1 : (int)imgSizeH / 2;
data[offset + 3] =
firstHalf ? (int)imgSizeH / 2 : imgSizeH;
data[offset + 4] = firstHalf ? 1 : (int)imgSizeW / 2;
data[offset + 5] =
firstHalf ? (int)imgSizeW / 2 : imgSizeW;
}
}
});
compare.addOutputs(BufferArg(VALUE_TYPE_FLOAT,
shape,
test_grad ? ADD_TO : ASSIGN_TO),
test_grad ? ADD_TO : ASSIGN_TO);
compare.run();
}
}
}
}
}
}
}
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "MulValueLayer.h"
#include "paddle/utils/Stat.h"
namespace paddle {
REGISTER_LAYER(mul_value, MulValueLayer);
bool MulValueLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
Layer::init(layerMap, parameterMap);
CHECK_EQ(static_cast<int>(inputLayers_.size()), 2);
auto& conf = config_.inputs(0).mul_value_conf();
value_ = conf.value();
createFunction(forward_, "MulValue", FuncConfig().set("value", value_));
createFunction(backward_, "MulValueGrad", FuncConfig().set("value", value_));
return true;
}
void MulValueLayer::forward(PassType passType) {
Layer::forward(passType);
auto in0 = getInput(0);
imgH_ = in0.getFrameHeight();
imgW_ = in0.getFrameWidth();
if (imgH_ == 0 || imgW_ == 0) {
auto& conf = config_.inputs(0).mul_value_conf();
imgH_ = conf.image_conf().img_size_y();
imgW_ = conf.image_conf().img_size();
}
MatrixPtr imgV = in0.value;
size_t batchSize = imgV->getHeight();
size_t spatialSize = imgH_ * imgW_;
channelsNum_ = imgV->getWidth() / spatialSize;
shape_ = TensorShape({batchSize, channelsNum_, imgH_, imgW_});
resetOutput(batchSize, imgV->getWidth());
MatrixPtr indicesV = getInputValue(1);
indicesShape_ = TensorShape({batchSize, 6});
REGISTER_TIMER_INFO("MulValueForward", getName().c_str());
BufferArgs inArgs;
BufferArgs outArgs;
inArgs.addArg(*imgV, shape_);
inArgs.addArg(*indicesV, indicesShape_);
MatrixPtr outV = getOutputValue();
outArgs.addArg(*outV, shape_, ASSIGN_TO);
forward_[0]->calc(inArgs, outArgs);
}
void MulValueLayer::backward(const UpdateCallback& callback) {
REGISTER_TIMER_INFO("MulValueBackward", getName().c_str());
BufferArgs inArgs;
BufferArgs outArgs;
inArgs.addArg(*getOutputGrad(), shape_);
inArgs.addArg(*getInputValue(1), indicesShape_);
outArgs.addArg(*getInputGrad(0), shape_, ADD_TO);
backward_[0]->calc(inArgs, outArgs);
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Layer.h"
namespace paddle {
/**
* \brief For each instance, this layer can be used to multiply a value to a
* specified sub continuous region. By providing start index and end
* index for C/H/W, you can specify the location and shape of the
* region.
*
* input_0: Input value.
* input_1: Indices value to specify the location an shape of the
* region.
*/
class MulValueLayer : public Layer {
public:
explicit MulValueLayer(const LayerConfig& config) : Layer(config) {}
~MulValueLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
protected:
TensorShape shape_;
TensorShape indicesShape_;
size_t imgH_;
size_t imgW_;
size_t channelsNum_;
real value_;
};
} // namespace paddle
......@@ -2358,6 +2358,37 @@ TEST(Layer, ScaleShiftLayer) {
}
}
TEST(Layer, MulValueLayer) {
const size_t batchSize = 64;
const size_t size = 4096;
TestConfig config;
config.layerConfig.set_type("mul_value");
config.inputDefs.push_back({INPUT_DATA, "input", size, 0});
MatrixPtr indicesV = Matrix::create(batchSize, 6, false, false);
auto* data = indicesV->getData();
for (size_t i = 0; i < batchSize; ++i) {
data[i * 2] = 2;
data[i * 2 + 1] = 4;
data[i * 2 + 2] = 16;
data[i * 2 + 3] = 32;
data[i * 2 + 4] = 16;
data[i * 2 + 5] = 32;
}
config.inputDefs.push_back({INPUT_SELF_DEFINE_DATA, "indices", indicesV, {}});
LayerInputConfig* input = config.layerConfig.add_inputs();
MulValueConfig* mulValueConf = input->mutable_mul_value_conf();
ImageConfig* imgConf = mulValueConf->mutable_image_conf();
imgConf->set_img_size(32);
imgConf->set_img_size_y(32);
imgConf->set_channels(4);
mulValueConf->set_value(1.0);
config.layerConfig.add_inputs();
for (auto useGpu : {false, true}) {
testLayerGrad(config, "mul_value", batchSize, false, useGpu, false);
}
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
initMain(argc, argv);
......
......@@ -169,7 +169,7 @@ void TensorCheck(AssertEq compare,
count++;
}
}
EXPECT_EQ(count, 0) << "There are " << count << " different element.";
EXPECT_EQ(count, 0) << "There are " << count << " different elements.";
}
template <typename AssertEq, typename Tensor1, typename Tensor2>
......
......@@ -321,6 +321,11 @@ message ClipConfig {
required double max = 2;
}
message MulValueConfig {
required ImageConfig image_conf = 1;
required float value = 2;
}
message LayerInputConfig {
required string input_layer_name = 1;
optional string input_parameter_name = 2;
......@@ -342,6 +347,7 @@ message LayerInputConfig {
optional MultiBoxLossConfig multibox_loss_conf = 16;
optional DetectionOutputConfig detection_output_conf = 17;
optional ClipConfig clip_conf = 18;
optional MulValueConfig mul_value_conf = 19;
}
message LayerConfig {
......
......@@ -3801,6 +3801,23 @@ class SwitchOrderLayer(LayerBase):
self.config.reshape_conf.width_axis.extend(reshape['width'])
@config_layer('mul_value')
class MulValueLayer(LayerBase):
def __init__(self, name, inputs, value, **xargs):
super(MulValueLayer, self).__init__(
name, 'mul_value', 0, inputs=inputs, **xargs)
mul_value_conf = self.config.inputs[0].mul_value_conf
mul_value_conf.value = value
# get channel, width and height from input_0 layer
input_layer = self.get_input_layer(0)
image_conf = mul_value_conf.image_conf
image_conf.img_size = input_layer.width
image_conf.img_size_y = input_layer.height
image_conf.channels = input_layer.size / (input_layer.width *
input_layer.height)
# Deprecated, use a new layer specific class instead
@config_func
def Layer(name, type, **xargs):
......
......@@ -144,6 +144,7 @@ __all__ = [
'img_conv3d_layer',
'resize_layer',
'sub_seq_layer',
'mul_value_layer',
]
......@@ -255,6 +256,8 @@ class LayerType(object):
RESIZE = 'resize'
SUB_SEQ_LAYER = 'subseq'
MUL_VALUE_LAYER = 'mul_value'
@staticmethod
def is_layer_type(type_name):
"""
......@@ -7037,3 +7040,50 @@ def sub_seq_layer(input, offsets, sizes, act=None, bias_attr=None, name=None):
LayerType.SUB_SEQ_LAYER,
parents=[input, offsets, sizes],
size=input.size)
@wrap_name_default('mul_value')
def mul_value_layer(input, indices, value, name=None):
"""
Given an image or feature map with CHW information, mul_value_layer can be
used to multiply a real value to values of a sub continuous region. You can
provide start and end indices of CHW for each instance. Please notice that
all start indices are counting from 1. The shape of indices should be
[batch_size, 6] and the layout for each row is [C_Start, C_End, H_Start,
H_End, W_Start, W_End].
.. code-block:: python
mul_value = mul_value_layer(input=input, indices=indices, value=value)
:param name: The name of this layer. It is optional.
:type name: basestring
:param input: The input of this layer which should contains CHW information.
:type input: LayerOutput
:param indices: Start index and end index for C H W, the input value should
be a 2-D matrix with shape [batch_size, 6].
:type indices: LayerOutput.
:param value: value to multiply.
:type value: float
:return: LayerOutput object.
:rtype: LayerOutput
"""
assert isinstance(input, LayerOutput), (
'The first input of mul_value_layer, must be a PaddlePaddle layer.')
assert isinstance(indices, LayerOutput), (
'The start and end indices for CHW, must be a PaddlePaddle layer.')
assert isinstance(value, float), (
'The value to multiply, must be a real value.')
Layer(
name=name,
type=LayerType.MUL_VALUE_LAYER,
inputs=[input.name, indices.name],
value=value)
return LayerOutput(
name,
LayerType.MUL_VALUE_LAYER,
parents=[input, indices],
size=input.size)
......@@ -10,6 +10,6 @@ test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_la
test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer
test_kmax_seq_socre_layer test_sub_nested_seq_select_layer test_scale_shift_layer
test_seq_slice_layer test_cross_entropy_over_beam test_pooling3D_layer
test_conv3d_layer test_deconv3d_layer test_BatchNorm3D test_resize_layer)
test_conv3d_layer test_deconv3d_layer test_BatchNorm3D test_resize_layer test_mul_value_layer)
export whole_configs=(test_split_datasource)
type: "nn"
layers {
name: "data"
type: "data"
size: 2016
active_type: ""
height: 48
width: 42
}
layers {
name: "indices"
type: "data"
size: 6
active_type: ""
}
layers {
name: "__mul_value_0__"
type: "mul_value"
active_type: ""
inputs {
input_layer_name: "data"
mul_value_conf {
image_conf {
channels: 1
img_size: 42
img_size_y: 48
}
value: 0.0
}
}
inputs {
input_layer_name: "indices"
}
}
input_layer_names: "data"
input_layer_names: "indices"
output_layer_names: "__mul_value_0__"
sub_models {
name: "root"
layer_names: "data"
layer_names: "indices"
layer_names: "__mul_value_0__"
input_layer_names: "data"
input_layer_names: "indices"
output_layer_names: "__mul_value_0__"
is_recurrent_layer_group: false
}
from paddle.trainer_config_helpers import *
settings(batch_size=1000, learning_rate=1e-5)
data = data_layer(name='data', size=2016, height=48, width=42)
indices = data_layer(name='indices', size=6)
mul_value = mul_value_layer(input=data, indices=indices, value=0.0)
outputs(mul_value)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册