TensorFlow函数教程:tf.nn.pool

由 Carrie 创建, 最后一次修改 2019-01-28

tf.nn.pool函数

tf.nn.pool(
    input,
    window_shape,
    pooling_type,
    padding,
    dilation_rate=None,
    strides=None,
    name=None,
    data_format=None
)

定义在:tensorflow/python/ops/nn_ops.py.

请参阅指南:神经网络>池操作

执行N-D池操作.

在data_format不以“NC”开头的情况下,计算0 <= b <batch_size,0 <= x [i] <output_spatial_shape [i],0 <= c <num_channels:

output[b, x[0], ..., x[N-1], c] =
  REDUCE_{z[0], ..., z[N-1]}
    input[b,
          x[0] * strides[0] - pad_before[0] + dilation_rate[0]*z[0],
          ...
          x[N-1]*strides[N-1] - pad_before[N-1] + dilation_rate[N-1]*z[N-1],
          c],

其中,还原函数REDUCE取决于pooling_type的值,并且pad_before是根据此处注释中描述的padding的值定义的.减少从不包括越界位置.

在data_format以“NC”开头的情况下,输入和输出简单地转置如下:

pool(input, data_format, **kwargs) =
  tf.transpose(pool(tf.transpose(input, [0] + range(2,N+2) + [1]),
                    **kwargs),
               [0, N+1] + range(1, N+1))

参数:

  • input:秩为N + 2的Tensor,如果data_format不以“NC”(默认),则shape为[batch_size] + input_spatial_shape + [num_channels];如果data_format以“NC”开头,则shape为[batch_size, num_channels] + input_spatial_shape.池化仅在空间维度上发生.
  • window_shape:N个int> = 1的序列.
  • pooling_type:指定池操作,必须是“AVG”或“MAX”.
  • padding:填充算法必须为“SAME”或“VALID”.
  • dilation_rate: 可选.扩张速度.N个int> = 1的列表.默认为[1] * N.如果dilation_rate的任何值> 1,则步幅的所有值必须为1.
  • strides: 可选.N个int> = 1的序列.默认为[1] * N.如果步幅的任何值> 1,则dilation_rate的所有值必须为1.
  • name: 可选.操作的名称.
  • data_format:string或None.指定input和输出的通道维度是最后一个维度(默认,或者data_format不是以“NC”开头),还是第二个维度(如果data_format以“NC”开头).对于N = 1,有效值为“NWC”(默认)和“NCW”.对于N = 2,有效值是“NHWC”(默认)和“NCHW”.对于N = 3,有效值为“NDHWC”(默认)和“NCDHW”.

返回:

秩为N + 2的张量,如果data_format为None或者不以“NC”开头,则shape为[batch_size] + output_spatial_shape + [num_channels],或者如果data_format以“NC”开头,则shape为

[batch_size,num_channels] + output_spatial_shape,其中output_spatial_shape取决于填充的值:

  • 如果padding =“SAME”:output_spatial_shape [i] = ceil(input_spatial_shape [i] / strides [i])
  • 如果padding =“VALID”:output_spatial_shape [i] = ceil((input_spatial_shape [i] - (window_shape [i] - 1)* dilation_rate [i])/ strides [i])

可能引发的异常:

  • ValueError:如果参数无效.
以上内容是否对您有帮助:

二维码
建议反馈
二维码