TensorFlow函数教程:tf.nn.embedding_lookup_sparse

2019-01-31 13:47 更新

tf.nn.embedding_lookup_sparse函数

tf.nn.embedding_lookup_sparse(
    params,
    sp_ids,
    sp_weights,
    partition_strategy='mod',
    name=None,
    combiner=None,
    max_norm=None
)

定义在:tensorflow/python/ops/embedding_ops.py.

请参阅指南:神经网络>Embeddings(嵌套)

计算给定 id 和 weight 的 embedding.

此操作假定由 sp_ids 表示的密集张量中的每一行至少有一个 id(即,没有具有空 feature 的行),并且 sp_ids 的所有 indice 都是规范的 row-major 顺序.

该函数还假设所有 id 值都在[0,p0]范围内,其中 p0 是沿着维度0的参数大小的总和.

参数:

  • params:表示完整的 embedding 张量的单张量,或除了第一维之外全部具有相同 shape 的 P 张量列表,表示切分的 embedding 张量.或者,一个 PartitionedVariable,通过沿维度0进行分区创建.对于给定的 partition_strategy,每个元素的大小必须适当.
  • sp_ids:int64 类型的 id 的 N x M SparseTensor(通常来自FeatureValueToId),其中 N 通常是批次大小,M 是任意的.
  • sp_weights:可以是具有 float/double weight 的 SparseTensor,或者是 None 以表示所有 weight 应为1.如果指定,则 sp_weights 必须具有与 sp_ids 完全相同的 shape 和 indice.
  • partition_strategy:指定切分策略的字符串,在 len(params) > 1 的情况下使用.目前支持两种切分方式:"div"和"mod",默认是"mod".查看tf.nn.embedding_lookup获取更多信息.
  • name:操作的可选名称.
  • combiner:指定 reduction 操作的字符串.目前支持“mean”,“sqrtn”和“sum”.“sum”计算每行的 embedding 结果的加权和.“mean”是加权和除以总 weight.“sqrtn”是加权和除以 weight 平方和的平方根.
  • max_norm:如果提供,则在组合之前将每个 embedding 规范化为具有等于 max_norm 的 l2 范数. 

返回:

表示稀疏 id 的组合 embedding 的密集张量.对于由 sp_ids 表示的密集张量中的每一行,操作查找该行中所有 id 的 embedding,将它们乘以相应的 weight,并按指定的方式组合这些 embedding.

换句话说,如果

shape(combined params) = [p0, p1, ..., pm]

并且:

shape(sp_ids) = shape(sp_weights) = [d0, d1, ..., dn]

然后:

shape(output) = [d0, d1, ..., dn-1, p1, ..., pm]

例如,如果 params 是一个 10x20 矩阵,则 sp_ids / sp_weights 是

[0, 0]: id 1, weight 2.0 [0, 1]: id 3, weight 0.5 [1, 0]: id 0, weight 1.0 [2, 3]: id 1, weight 3.0

如果 combiner=“mean”,那么输出将是3x20矩阵,其中:

output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5) output[1, :] = (params[0, :] * 1.0) / 1.0 output[2, :] = (params[1, :] * 3.0) / 3.0

可能引发的异常:

  • TypeError:如果 sp_ids 不是 SparseTensor,或者 sp_weights 既不是 None 也不是 SparseTensor.
  • ValueError:如果 combiner 不是 {“mean”,“sqrtn”,“sum”} 中的一个.
以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号