HybridSN代码修改

发布时间:2022-06-21 发布网站:脚本宝典
脚本宝典收集整理的这篇文章主要介绍了HybridSN代码修改脚本宝典觉得挺不错的,现在分享给大家,也给大家做个参考。

研究小组里初学习深度学习的同学我都布置写过 HybridSN 的代码:https://github.com/OUCTheoryGroup/colab_demo/blob/master/202003_models/HybridSN_GRSL2020.ipynb

最近做SRDP的同学反映,跑 Pavia 数据集的时候内存会爆,主要原因是 createImageCubes 这个函数有个地方:

patchesData = np.zeros([ width*height, windowSize, windowsSize, spectral_num])

因为 Pavia 数据集尺寸较大,width*height 就比较大了,内存会爆掉。

其实图像中大部分为是0,没有label,我们要取的,只是有label的部分。现在我改了改,先做个循环,看看有多少个像素有 label,然后记录在 count 里,分配内存时:

patchesData = np.zeros([count, windowSize, windowSize, spectral_num])

这样 count 比以前的 width*height 要小很多,内存就不会爆了

修改后的代码如下,供感兴趣的同学参考(Github上的我就不改了,留给以后新同学排雷):

# 在每个像素周围提取 patch 
def createImageCubes(X, y, windowSize=5, removeZeroLabels = True):
    # 给 X 做 padding
    margin = int((windowSize - 1) / 2)
    zeroPaddedX = padWithZeros(X, margin=margin)
    # 获得 y 中的标记样本数
    count = 0
    for r in range(0, y.shape[0]):
        for c in range(0, y.shape[1]):
            if y[r, c] != 0:
                count = count+1

    # split patches
    patchesData = np.zeros([count, windowSize, windowSize, X.shape[2]])
    patchesLabels = np.zeros(count)

    count = 0
    for r in range(margin, zeroPaddedX.shape[0] - margin):
        for c in range(margin, zeroPaddedX.shape[1] - margin):
            if y[r-margin, c-margin] != 0:
                patch = zeroPaddedX[r - margin:r + margin + 1, c - margin:c + margin + 1]   
                patchesData[count, :, :, :] = patch
                patchesLabels[count] = y[r-margin, c-margin]
                count = count + 1

    return patchesData, patchesLabels

脚本宝典总结

以上是脚本宝典为你收集整理的HybridSN代码修改全部内容,希望文章能够帮你解决HybridSN代码修改所遇到的问题。

如果觉得脚本宝典网站内容还不错,欢迎将脚本宝典推荐好友。

本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
如您有任何意见或建议可联系处理。小编QQ:384754419,请注明来意。
标签: