diff --git a/paddle/fluid/operators/one_hot_op.cc b/paddle/fluid/operators/one_hot_op.cc index 626895f49d8d4347f1e9a40526943cf00c73d034..cbb0c4028b3daa927529456e76253d93857a58b5 100644 --- a/paddle/fluid/operators/one_hot_op.cc +++ b/paddle/fluid/operators/one_hot_op.cc @@ -13,6 +13,8 @@ // limitations under the License. #include "paddle/fluid/operators/one_hot_op.h" +#include +#include #include "paddle/fluid/framework/framework.pb.h" namespace paddle { @@ -34,15 +36,34 @@ class OneHotOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_GE(x_dims[x_dims.size() - 1], 1U, "Last dimension of Input(X) should be 1."); } - int depth = ctx->Attrs().Get("depth"); - - PADDLE_ENFORCE_GT(depth, 0, "Should provide a positive depth (%d).", depth); framework::DDim out_dims(x_dims); + int depth = ctx->Attrs().Get("depth"); + if (ctx->HasInput("depth_tensor")) { + depth = -1; + } + out_dims[out_dims.size() - 1] = depth; ctx->SetOutputDim("Out", out_dims); ctx->ShareLoD("X", /* --> */ "Out"); } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); + } + + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const override { + if (var_name == "depth_tensor") { + return expected_kernel_type; + } + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } }; class OneHotOpMaker : public framework::OpProtoAndCheckerMaker { @@ -52,11 +73,15 @@ class OneHotOpMaker : public framework::OpProtoAndCheckerMaker { "(LoDTensor, LoDTensor) Input variable with rank at least 2. " "The last dimension of X should be 1. 
Each value of X is an index " "to indicate the position."); + AddInput("depth_tensor", "(Tensor, Tensor), Length of one-hot vector") + .AsDispensable(); AddOutput("Out", "(Tensor, Tensor) Output tensor with same rank as X. " "The tensor consists of one-hot representations of values in X."); + AddAttr("depth", - "A positive integer to specify the length of one-hot vector."); + "A positive integer to specify the length of one-hot vector.") + .SetDefault(-1); AddAttr("dtype", "An integer to specify the data type of one-hot " "vector. The default value is FP32.") diff --git a/paddle/fluid/operators/one_hot_op.cu b/paddle/fluid/operators/one_hot_op.cu index 59d8b9b8a8d554eb16826712ff634eed5df2d648..b9fe0bf2e9dc46ecc3974455e1328d8a83bcf388 100644 --- a/paddle/fluid/operators/one_hot_op.cu +++ b/paddle/fluid/operators/one_hot_op.cu @@ -62,8 +62,25 @@ class OneHotCUDAKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { auto* in = context.Input("X"); auto* out = context.Output("Out"); - int depth = context.Attr("depth"); + int depth = -1; + if (context.HasInput("depth_tensor")) { + auto* depth_tensor = context.Input("depth_tensor"); + if (platform::is_gpu_place(depth_tensor->place())) { + framework::Tensor temp; + TensorCopySync(*depth_tensor, platform::CPUPlace(), &temp); + depth = *temp.data(); + } else { + depth = *depth_tensor->data(); + } + + auto in_dims = in->dims(); + framework::DDim out_dims(in_dims); + out_dims[out_dims.size() - 1] = depth; + out->Resize(out_dims); + } else { + depth = context.Attr("depth"); + } framework::VisitDataType( static_cast( context.Attr("dtype")), diff --git a/paddle/fluid/operators/one_hot_op.h b/paddle/fluid/operators/one_hot_op.h index 1ebd2676496940ff8f90caaaded5c8227bd7ae78..7273080927ecd9b35d72c272e2d8b4254a0c3991 100644 --- a/paddle/fluid/operators/one_hot_op.h +++ b/paddle/fluid/operators/one_hot_op.h @@ -49,6 +49,7 @@ struct OneHotOpFunctor { }; using LoDTensor = 
framework::LoDTensor; +using Tensor = framework::Tensor; template class OneHotKernel : public framework::OpKernel { public: @@ -56,6 +57,15 @@ class OneHotKernel : public framework::OpKernel { auto* in = context.Input("X"); auto* out = context.Output("Out"); int depth = context.Attr("depth"); + if (context.HasInput("depth_tensor")) { + auto* depth_tensor = context.Input("depth_tensor"); + auto* depth_data = depth_tensor->data(); + depth = depth_data[0]; + auto in_dims = in->dims(); + framework::DDim out_dims(in_dims); + out_dims[out_dims.size() - 1] = depth; + out->Resize(out_dims); + } framework::VisitDataType( static_cast( diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 0fe554bfa493e25d08cf21b175110c13d2720dee..ab2d1d4049709c80ed3c4bea682155d8462b2750 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1013,7 +1013,7 @@ class Operator(object): return if type is None: raise ValueError( - "`type` to initilized an Operator can not be None.") + "`type` to initialized an Operator can not be None.") else: callstack_var_name = op_maker.kOpCreationCallstackAttrName() op_attrs[callstack_var_name] = list( @@ -1036,7 +1036,6 @@ class Operator(object): found = find_name(inputs, in_proto.name) assert found or in_proto.dispensable, "Input {} not found".format( in_proto.name) - if found: in_args = inputs[in_proto.name] if not isinstance(in_args, list): @@ -1046,13 +1045,17 @@ class Operator(object): "Input %s expects only one input, but %d are given." 
% (in_proto.name, len(in_args))) in_arg_names = [] - for arg in in_args: + for index, arg in enumerate(in_args): if isinstance(arg, six.string_types): in_arg_names.append(arg) elif isinstance(arg, six.binary_type): in_arg_names.append(arg.decode()) - else: + elif isinstance(arg, Variable): in_arg_names.append(cpt.to_text(arg.name)) + else: + raise ValueError( + "not supported args type, should be [string_type, binary_type, Variable]" + ) self.desc.set_input(in_proto.name, in_arg_names) else: self.desc.set_input(in_proto.name, []) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 0718235fc0d86f6c0198c4e69b5cfa13d0418227..447bd719233152f2a3f58fe97f4d566a01a765d6 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -6564,11 +6564,24 @@ def one_hot(input, depth): one_hot_label = fluid.layers.one_hot(input=label, depth=10) """ helper = LayerHelper("one_hot", **locals()) + one_hot_out = helper.create_variable_for_type_inference(dtype='float32') + + if in_dygraph_mode(): + inputs = {'X': input} + attrs = {'depth': depth} + else: + if not isinstance(depth, Variable): + # user attribute + inputs = {'X': input} + attrs = {'depth': depth} + else: + inputs = {'X': input, 'depth_tensor': depth} + attrs = {} helper.append_op( type="one_hot", - inputs={'X': input}, - attrs={'depth': depth}, + inputs=inputs, + attrs=attrs, outputs={'Out': one_hot_out}, stop_gradient=True) return one_hot_out diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 4136bb7fef1054ee5698e05bb297e25e20317cf4..9217a3fc446c7999a5f20ea33b12f05795e7e710 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -1267,6 +1267,12 @@ class TestBook(LayerTest): out = layers.scatter(input=x, index=idx, updates=updates) return (out) + def make_one_hot(self): + with 
fluid.framework._dygraph_place_guard(place=fluid.CPUPlace()): + label = self._get_data(name="label", shape=[1], dtype="int32") + one_hot_label = layers.one_hot(input=label, depth=10) + return (one_hot_label) + def make_label_smooth(self): # TODO(minqiyang): support gpu ut self._force_to_use_cpu = True diff --git a/python/paddle/fluid/tests/unittests/test_one_hot_op.py b/python/paddle/fluid/tests/unittests/test_one_hot_op.py index 7afdae804a65b9fb05a521a1b08ce0bfb21d721f..f213a0c77f4babdb46626c6e7d9b631a4e79a631 100644 --- a/python/paddle/fluid/tests/unittests/test_one_hot_op.py +++ b/python/paddle/fluid/tests/unittests/test_one_hot_op.py @@ -28,10 +28,34 @@ class TestOneHotOp(OpTest): def setUp(self): self.op_type = 'one_hot' depth = 10 + depth_np = np.array(10).astype('int32') dimension = 12 x_lod = [[4, 1, 3, 3]] x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))] - x = np.array(x).astype('int').reshape([sum(x_lod[0]), 1]) + x = np.array(x).astype('int32').reshape([sum(x_lod[0]), 1]) + + out = np.zeros(shape=(np.product(x.shape[:-1]), + depth)).astype('float32') + + for i in range(np.product(x.shape)): + out[i, x[i]] = 1.0 + + self.inputs = {'X': (x, x_lod), 'depth_tensor': depth_np} + self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)} + self.outputs = {'Out': (out, x_lod)} + + def test_check_output(self): + self.check_output() + + +class TestOneHotOp_attr(OpTest): + def setUp(self): + self.op_type = 'one_hot' + depth = 10 + dimension = 12 + x_lod = [[4, 1, 3, 3]] + x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))] + x = np.array(x).astype('int32').reshape([sum(x_lod[0]), 1]) out = np.zeros(shape=(np.product(x.shape[:-1]), depth)).astype('float32') @@ -40,7 +64,7 @@ class TestOneHotOp(OpTest): out[i, x[i]] = 1.0 self.inputs = {'X': (x, x_lod)} - self.attrs = {'depth': depth, 'dtype': int(core.VarDesc.VarType.FP32)} + self.attrs = {'dtype': int(core.VarDesc.VarType.FP32), 'depth': depth} self.outputs = {'Out': (out, x_lod)} 
def test_check_output(self): @@ -48,13 +72,37 @@ class TestOneHotOp(OpTest): class TestOneHotOp_default_dtype(OpTest): + def setUp(self): + self.op_type = 'one_hot' + depth = 10 + depth_np = np.array(10).astype('int32') + dimension = 12 + x_lod = [[4, 1, 3, 3]] + x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))] + x = np.array(x).astype('int32').reshape([sum(x_lod[0]), 1]) + + out = np.zeros(shape=(np.product(x.shape[:-1]), + depth)).astype('float32') + + for i in range(np.product(x.shape)): + out[i, x[i]] = 1.0 + + self.inputs = {'X': (x, x_lod), 'depth_tensor': depth_np} + self.attrs = {} + self.outputs = {'Out': (out, x_lod)} + + def test_check_output(self): + self.check_output() + + +class TestOneHotOp_default_dtype_attr(OpTest): def setUp(self): self.op_type = 'one_hot' depth = 10 dimension = 12 x_lod = [[4, 1, 3, 3]] x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))] - x = np.array(x).astype('int').reshape([sum(x_lod[0]), 1]) + x = np.array(x).astype('int32').reshape([sum(x_lod[0]), 1]) out = np.zeros(shape=(np.product(x.shape[:-1]), depth)).astype('float32') diff --git a/python/paddle/fluid/tests/unittests/test_operator_desc.py b/python/paddle/fluid/tests/unittests/test_operator_desc.py index 37b9a9188ab44df81029ae6d9925ae21c1929cff..aa9634a2d419cbe791b42af526fe2e2bc37a5727 100644 --- a/python/paddle/fluid/tests/unittests/test_operator_desc.py +++ b/python/paddle/fluid/tests/unittests/test_operator_desc.py @@ -33,7 +33,7 @@ class TestOperator(unittest.TestCase): except ValueError as v_err: self.assertEqual( cpt.get_exception_message(v_err), - "`type` to initilized an Operator can not be None.") + "`type` to initialized an Operator can not be None.") try: block.append_op(type="no_such_op") self.assertFail() diff --git a/python/paddle/fluid/tests/unittests/test_recordio_reader.py b/python/paddle/fluid/tests/unittests/test_recordio_reader.py index 
f5009556adc8951aad80532d77cac4b920887c66..0417da7228e96ed8daffa7bbfcb7c12358cd78ec 100644 --- a/python/paddle/fluid/tests/unittests/test_recordio_reader.py +++ b/python/paddle/fluid/tests/unittests/test_recordio_reader.py @@ -86,3 +86,7 @@ class TestRecordIO(unittest.TestCase): def test_double_buffer_reader(self): self.test_main(decorator_callback=lambda reader: fluid.layers.io.double_buffer(reader, place='cuda:0' if fluid.core.is_compiled_with_cuda() else 'cpu')) + + +if __name__ == '__main__': + unittest.main()