TensorFlow分区索引的使用

由 Carrie 创建, 最后一次修改 2017-10-10

tf.dynamic_partition

dynamic_partition(
    data,
    partitions,
    num_partitions,
    name=None
)

参见指南:张量变换>分割和连接

使用分区中的索引将数据分成 num_partitions 的张量.

对于大小为 partitions.ndim 的每个索引元组 js,切片数据为 [js, ...] ,成为 outputs[partitions[js]] 其中的一部分.这些 partitions[js] = i 切片以 js 的词典顺序被放置在 outputs[i] 中,outputs[i] 的第一个维度是分区中的条目数等于 i.详细参考如下:

outputs[i].shape = [sum(partitions == i)] + data.shape[partitions.ndim:]

outputs[i] = pack([data[js, ...] for js if partitions[js] == i])

data.shape 必须和 partitions.shape 一起启动.

例如:

# Scalar partitions.
partitions = 1
num_partitions = 2
data = [10, 20]
outputs[0] = []  # Empty with shape [0, 2]
outputs[1] = [[10, 20]]

# Vector partitions.
partitions = [0, 0, 1, 1, 0]
num_partitions = 2
data = [10, 20, 30, 40, 50]
outputs[0] = [10, 20, 50]
outputs[1] = [30, 40]

有关 dynamic_stitch 如何将分区合并回来的示例.

TensorFlow合并分区的实例

ARGS:

  • data:一个 Tensor.
  • partitions:一个 int32 类型的张量.可以是任何形状.索引在范围 [0, num_partitions) 内.
  • num_partitions:一个大于等于1的整数;要输出的分区数.
  • name:操作的名称(可选).

返回:

与数据具有相同类型的 num_partitions 张量对象的列表.

以上内容是否对您有帮助:
二维码
建议反馈
二维码