Skip to content
章节导航

🕒 Published at: a few seconds ago
python
"""
@author:Nan J
@date  :2022/12/8:10:11
@IDE   :Spyder
"""
import os
import os.path
from osgeo import gdal
import sys
from osgeo import gdalconst
from osgeo import gdal
from osgeo import osr
import numpy as np
# coding=utf-8
from tqdm import tqdm


def WriteGTiffFile(filename, nRows, nCols, data, geotrans, proj, noDataValue, gdalType, nandID):  # 向磁盘写入结果文件
    print(filename)
    format = "GTiff"
    driver = gdal.GetDriverByName(format)
    ds = driver.Create(filename, nCols, nRows, 1, gdalType)
    ds.SetGeoTransform(geotrans)
    ds.SetProjection(proj)
    ds.GetRasterBand(nandID).SetNoDataValue(noDataValue)
    ds.GetRasterBand(nandID).WriteArray(data)
    ds = None


def File(bandID):  # 遍历文件,读取数据,算出均值
    rows, cols, geotransform, projection, noDataValue = Readxy('C://126//20211.tif')
    # 获取源文件的行,列,投影等信息,注意需要保持所有的源文件这些信息都是一致的
    print('rows and cols is ', rows, cols)
    #rows, cols = 200, 200
    #filesum = [[0.0] * cols] * rows  # 栅格值和,二维数组
    #average = [[0.0] * cols] * rows  # 存放平均值,二维数组,循环定义太慢
    filesum = np.zeros((rows, cols), dtype=np.float32)  # 转换类型为np.array
    average = np.zeros((rows, cols), dtype=np.float32)
    print('the type of filesum', type(filesum))
    count = 0
    rootdir = 'C://126'
    for dirpath, filename, filenames in os.walk(rootdir):  # 遍历源文件
        for filename in filenames:
            if os.path.splitext(filename)[1] == '.tif':  # 判断是否为tif格式
                filepath = os.path.join(dirpath, filename)
                purename = filename.replace('.tif', '')  # 获得除去扩展名的文件名,比如20211.tif,purename为20211
                #两种修改方式,均要修改RUN.write_img
                #1)路径里面的目录不要是rootdir
                #2)文件名不要是"2021_band_',用2021开头,随便起一个”save_2021_band_'“
                #逻辑是文件名的前4位是2021都要做累加
                if purename[:4] == '2021':  # 判断年份
                    # filedata = [[0.0] * cols] * rows  这也没有用到
                    # filedata = np.array(filedata)
                    print('start to read the data of :', filepath, ' and the band is ', bandID)		
                    filedata = Read(filepath, bandID)  # 将各年的多幅图像数据存入filedata中
                    count = count + 1
                    print('end to read the data of :', filepath, ' and the band is ', bandID)	
                    np.add(filesum, filedata, filesum)  # 求3幅图像相应栅格值的和
                    # print str(count)+'this is filedata',filedata
    print('count is ', count)
    for i in range(0, rows):
        for j in range(0, cols):
            # print(noDataValue, count)
            if (filesum[i, j] == noDataValue):  # 处理图像中的noData
                average[i, j] = -9999
            else:
                average[i, j] = filesum[i, j] * 1.0 / count  # 求平均
    RUN = GRID()
    RUN.write_img('C://126' + '//' + 'save2021_band_'+ str(bandID) + '.tif', projection, geotransform, average)
    #WriteGTiffFile("E://RemoteSensing//Image//meantest//2080_band"+ str(bandID) + ".tif", rows, cols, average,geotransform,projection, -9999, 10, bandID) #写入结果文件
    #return average

class GRID:

    # 读图像文件
    def read_img(self, filename):
        dataset = gdal.Open(filename)  # 打开文件
        im_width = dataset.RasterXSize  # 栅格矩阵的列数
        im_height = dataset.RasterYSize  # 栅格矩阵的行数
        im_geotrans = dataset.GetGeoTransform()  # 仿射矩阵
        im_proj = dataset.GetProjection()  # 地图投影信息
        im_data = dataset.ReadAsArray(0, 0, im_width, im_height)  # 将数据写成数组,对应栅格矩阵
        del dataset  # 关闭对象,文件dataset
        return im_proj, im_geotrans, im_data, im_width, im_height

    # 写文件,以写成tif为例
    def write_img(self, filename, im_proj, im_geotrans, im_data):
        # 判断栅格数据的数据类型
        if 'int8' in im_data.dtype.name:
            datatype = gdal.GDT_Byte
        elif 'int16' in im_data.dtype.name:
            datatype = gdal.GDT_UInt16
        else:
            datatype = gdal.GDT_Float32
        # 判读数组维数
        if len(im_data.shape) == 3:
            im_bands, im_height, im_width = im_data.shape
        else:
            im_bands, (im_height, im_width) = 1, im_data.shape
        # 创建文件
        driver = gdal.GetDriverByName("GTiff")  # 数据类型必须有,因为要计算需要多大内存空间
        dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
        dataset.SetGeoTransform(im_geotrans)  # 写入仿射变换参数
        dataset.SetProjection(im_proj)  # 写入投影
        if im_bands == 1:
            dataset.GetRasterBand(1).WriteArray(im_data)  # 写入数组数据
        else:
            for i in range(im_bands):
                dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
        del dataset


def Readxy(RasterFile):  # 读取每个图像的信息
    ds = gdal.Open(RasterFile, gdal.GA_ReadOnly)
    if ds is None:
        print('Cannot open ', RasterFile)
        sys.exit(1)
    cols = ds.RasterXSize
    rows = ds.RasterYSize
    band = ds.GetRasterBand(1)
    #data = band.ReadAsArray(0, 0, cols, rows)   这一个也没有用到啊
    noDataValue = band.GetNoDataValue()
    projection = ds.GetProjection()
    geotransform = ds.GetGeoTransform()
    return rows, cols, geotransform, projection, noDataValue


def Read(RasterFile, bandID):  # 读取每个图像的信息
    ds = gdal.Open(RasterFile, gdal.GA_ReadOnly)
    if ds is None:
        print('Cannot open ', RasterFile)
        sys.exit(1)
    cols = ds.RasterXSize
    rows = ds.RasterYSize
    band = ds.GetRasterBand(bandID)
    data = band.ReadAsArray(0, 0, cols, rows)
    return data


if __name__ == "__main__":
    print("ok1")

    for i in tqdm(range(1, 20)):#20是波段数量加1,比如原始有19个波段,这里就填20.
        File(1)

    print("ok2")