# coding: utf-8


# coding: utf-8
# NAME: 
# FILE: Registration_for_omics_v7_from_SIS.py
# REVISION : 1.1.0 - 18/10/2022
# AUTHOR : Mariia Burdyniuk 
# Copyright(c) 2022 arivis AG, Germany. All Rights Reserved.
#
# Permission is granted to use, modify and distribute this code,
# as long as this copyright notice remains part of the code.
#
# PURPOSE: The script performs image registration based on the 'FIXED' DAPI staining in the first cycle.
# For each tile, a new imageset is created containing the unchanged FIXED imageset (first cycle) and all registered channels.
# Optionally, the user can remove all 'moving' DAPI channels.

# Each file name must end with '-XX.<extension>', where XX is the numeric cycle number.
# The cycles are sorted numerically in ascending order, and the first (lowest) cycle is used as the fixed imageset for registration.

# Please indicate the DAPI channel of each imageset in IMAGESET_DAPI_CHANNEL.

# Tested for V4D Release : 4.0.0 -- tested on : 2022-10-18

# NOTE: 
#%% External Package Import 
import time
import sys
import arivis
import arivis_core as core
import itk
import numpy as np
from glob import glob

#%% User defined variables

# True: convert every file in ORIG_IMAGE_PATH to an SIS file in SIS_FILE_PATH
# before registering. False: use the imagesets that are already open.
RUN_SIS_IMPORT = False

# Name passed to elastix' GetDefaultParameterMap.
REGISTRATION_METHOD = 'affine'
# choose from 'translation' 'rigid' 'affine'  'bspline'  'spline'

# Both paths are used by plain string concatenation, so keep the trailing '\\'.
ORIG_IMAGE_PATH = 'D:\\demo_test\\'   # raw files are matched with ORIG_IMAGE_PATH + '*.*'
SIS_FILE_PATH = 'D:\\Reg_result\\'    # output folder for the 'Cycle_XX' SIS files and the result document

# Each imageset's planes are split into TILE_NUMBER equal consecutive blocks (tiles).
TILE_NUMBER = 9

# True: read and register single 2D planes instead of 3D stacks.
TWO_D = False

# 1-based DAPI channel index of each cycle, in ascending cycle order
# (one entry per cycle, e.g. [3, 1, 2, 4]); the first entry is the FIXED cycle.
IMAGESET_DAPI_CHANNEL = [3, 1, 2] #, 4, 4, 4, 4, 4] # channel count starts from 1!

# False: the registered DAPI channels of the moving cycles are not written.
# The channels of the FIXED cycle, including its DAPI channel, are always copied.
KEEP_DAPI_CHANNEL = True

# Additional (Advanced) variables -- elastix settings; larger values increase computation time
MAX_NUMBER_OF_ITERATIONS = 512         # elastix MaximumNumberOfIterations (author's note: default 256)
MAX_NUMBER_OF_SAMPLING_ATTEMPTS = 16   # elastix MaximumNumberOfSamplingAttempts (author's note: default 8)
NUMBER_OF_RESOLUTIONS = 6              # elastix NumberOfResolutions (author's note: default 3)

RESOLUTIONS = 6             # resolution levels intended for the bspline transform (author's note: default 3)
GRID = 30                   # final grid spacing passed to GetDefaultParameterMap for bspline only (author's note: default 60)

#%%
# When the destination viewer has no document yet, a new document named
# SIS_FILE_PATH + NEW_IMAGESET_NAME is created; the per-tile imagesets inside it
# are named '<fixed imageset name>_tile_<N>'.
NEW_IMAGESET_NAME = '20231117_Reg_affine'
# Inserted between the source channel name and the cycle number in output channel names.
OUTPUT_CH_APPEND = '_cycle_'

DAPI  = '_DAPI'   # suffix appended to the names of registered DAPI channels
TIMEPOINT = 0     # timepoint that is read; time-series data is not supported

#%% Function: getPixelTypes()
def getPixelTypes(imageset):
    """Map an arivis imageset's pixel type to NumPy type information.

    Parameters
    ----------
    imageset : arivis ImageSet
        Imageset whose pixel type is queried with ``get_pixeltype()``.

    Returns
    -------
    tuple
        ``(dt, typeP, Coefficent)``:

        - the ``numpy.dtype`` used for read buffers;
        - the one-character NumPy type code used when casting results
          ('H', 'L', 'f' or 'B');
        - a scale constant (255 for UCHAR, 65535 otherwise). The callers
          visible in this script discard it.

        Any pixel type other than USHORT/ULONG/FLOAT is treated as UCHAR.

    Exits the script (``sys.exit``) when *imageset* is None.
    """
    if imageset is None:
        # Every caller unpacks three values. Returning None only produced an
        # opaque "cannot unpack NoneType" TypeError, so stop with a clear message.
        print("No Image Set open")
        sys.exit("No Image Set open. Aborting the script.")

    pixelType = imageset.get_pixeltype()
    if pixelType == core.ImageSet.PIXELTYPE_USHORT:
        return np.dtype(np.ushort), 'H', 65535
    if pixelType == core.ImageSet.PIXELTYPE_ULONG:
        # NOTE(review): np.uint and the 'L' code have platform-dependent widths
        # (e.g. 32-bit on Windows with NumPy < 2). Confirm they match arivis' ULONG.
        return np.dtype(np.uint), 'L', 65535
    if pixelType == core.ImageSet.PIXELTYPE_FLOAT:
        return np.dtype(np.float32), 'f', 65535
    # Fallback: PIXELTYPE_UCHAR and any unrecognized pixel type.
    return np.dtype(np.uint8), 'B', 255

#%% Function: insertResultImage(image, imageset, ch_name, typeP)
def insertResultImage(image, imageset, ch_name, typeP):
    """Append a channel named *ch_name* to *imageset* and write *image* into it.

    The data is written over the imageset's full 3D bounding box, at the
    bounding box's first timepoint (``t1``).

    Parameters
    ----------
    image : numpy.ndarray
        Pixel data. Its shape must match the imageset's bounding box; this is
        not checked here.
    imageset : arivis ImageSet
        Destination imageset. A channel is appended via prepareChannel().
    ch_name : str
        Name of the new channel.
    typeP : str
        NumPy type code ('B', 'H', 'L' or 'f', as returned by getPixelTypes).
        The data is converted to this type before writing.
    """
    print('Writing a new channel')
    reg_channel = prepareChannel(imageset, ch_name)

    # The write region is the imageset's 4D bounding box without the time axis.
    boundingbox4d = imageset.get_bounding_box()
    boundingbox3d = core.Bounds3D()
    boundingbox3d.x1 = boundingbox4d.x1
    boundingbox3d.x2 = boundingbox4d.x2
    boundingbox3d.y1 = boundingbox4d.y1
    boundingbox3d.y2 = boundingbox4d.y2
    boundingbox3d.z1 = boundingbox4d.z1
    boundingbox3d.z2 = boundingbox4d.z2

    TP = boundingbox4d.t1

    target_dtype = np.dtype(typeP)
    if np.issubdtype(target_dtype, np.integer) and not np.issubdtype(image.dtype, np.integer):
        # Registered/transformed data arrives as float. Casting a float that is
        # out of the integer range wraps around or is undefined, so clamp to the
        # representable range first (e.g. interpolation overshoot above the maximum).
        limits = np.iinfo(target_dtype)
        image = np.clip(image, limits.min, limits.max)
    image = image.astype(target_dtype)
    imageset.write_imagedata(buffer=image, region=boundingbox3d, channel=reg_channel, timepoint=TP)

#%% Function: PrepareChannel
def prepareChannel(imageset, chname):
    """Append one empty channel to *imageset* and return its zero-based index.

    The new channel is inserted after the existing ones. When *chname* is
    non-empty, the last channel (the one just inserted) is renamed to it.
    """
    new_index = imageset.get_channel_count()
    imageset.insert_channels(new_index, 1)
    if chname != "":
        imageset.set_channel_name(imageset.get_channel_count() - 1, chname)
    return new_index
       
#%% Function: registerImages
def registerImages(fixed_image, moving_image, parameter_object, Z_spacing_ratio):
    """Register *moving_image* onto *fixed_image* with itk-elastix.

    Both NumPy arrays are converted to float32 and wrapped as ITK image views
    before ``itk.elastix_registration_method`` is called. In 3D mode
    (``TWO_D`` False), the fixed image gets spacing ``Z_spacing_ratio`` along Z
    and 1.0 along X/Y. The moving image gets unit spacing.

    Parameters
    ----------
    fixed_image, moving_image : numpy.ndarray
        Image data, laid out (planes, rows, cols) in 3D mode, as
        readOneImageStack produces.
    parameter_object : itk.ParameterObject
        Elastix parameter maps.
    Z_spacing_ratio : float
        Fixed/moving spacing ratio applied to the Z axis. Ignored in 2D mode.

    Returns
    -------
    tuple
        ``(result_image, result_transform_parameters)``: the registered image
        returned by elastix as a float32 NumPy array, and the transform
        parameters, which can be reused with ``itk.transformix_filter``.
    """
    print('Registering two images')

    # image_view_from_array shares memory with its input. Keep the float32
    # arrays in named locals so they stay alive while elastix reads the views.
    fixed_array = fixed_image.astype(np.float32)
    moving_array = moving_image.astype(np.float32)
    fixed_itk = itk.image_view_from_array(fixed_array)
    moving_itk = itk.image_view_from_array(moving_array)

    if not TWO_D:
        XY_spacing_ratio = 1.0
        # ITK spacing is ordered (x, y, z), the reverse of the NumPy (z, y, x)
        # layout, so the Z ratio belongs in the LAST position.
        # NOTE(review): the caller derives the ratio from get_pixel_size()[0];
        # confirm that index 0 is the Z size in arivis.
        fixed_itk.SetSpacing([XY_spacing_ratio, XY_spacing_ratio, Z_spacing_ratio])
        moving_itk.SetSpacing([XY_spacing_ratio, XY_spacing_ratio, XY_spacing_ratio])

    # Call registration function FIXED, MOVING...
    result_image, result_transform_parameters = itk.elastix_registration_method(
        fixed_itk, moving_itk, parameter_object=parameter_object, log_to_console=False)
    result_image = np.asarray(result_image).astype(np.float32)

    return result_image, result_transform_parameters

#%% Function: SisFileImport
def SisFileImport(file_name):
    """Import one raw image file into an SIS file named after its cycle number.

    The cycle token is the text between the last '-' and the following '.'
    of *file_name* (e.g. ``'...-03.czi'`` gives ``'03'``). The SIS target is
    ``SIS_FILE_PATH + 'Cycle_' + token``. Raises ValueError if the token is
    not an integer.
    """
    last_dash_part = file_name.split('-')[-1]
    print('Found image: ', file_name)
    cycle_token = last_dash_part.split('.')[0]
    print('Cycle number detected: ', int(cycle_token))
    target_path = SIS_FILE_PATH + 'Cycle_' + cycle_token
    print('Importing image', file_name)
    arivis.Import.import_single_file(file_name, target_path, True)

#%% Function: prepareRegisteredImageset()
def prepareRegisteredImageset(fixed_imageset, tile, viewer, fixed_Plane_Number):
    """Create the empty imageset that receives the registered data of one tile.

    The new imageset:

    - is named ``<fixed imageset name>_tile_<tile>``;
    - has the fixed imageset's pixel type and XY extent;
    - has one timepoint and *fixed_Plane_Number* planes;
    - starts with no channels (they are appended later).

    If *viewer* has no document yet, one named
    ``SIS_FILE_PATH + NEW_IMAGESET_NAME`` is created. Otherwise, an existing
    imageset with the same name is deleted first. The document is saved
    before returning.

    Returns
    -------
    tuple
        ``(new_imageset, imageset_name)``.
    """
    imageset_name = fixed_imageset.get_name() + '_tile_' + str(tile)
    im_name = SIS_FILE_PATH + NEW_IMAGESET_NAME

    BBox = fixed_imageset.get_bounding_box()
    # Use an instance, as with core.Bounds3D() elsewhere in this script.
    # Assigning attributes on the class core.Bounds2D itself would mutate
    # shared class state.
    bounds = core.Bounds2D()
    bounds.x = BBox.x1
    bounds.width = BBox.x2 - BBox.x1 + 1
    bounds.y = BBox.y1
    bounds.height = BBox.y2 - BBox.y1 + 1

    planes = fixed_Plane_Number

    # getPixelTypes only returns these four codes.
    _, typeP, _ = getPixelTypes(fixed_imageset)
    pixel_type = {
        'H': core.ImageSet.PIXELTYPE_USHORT,
        'B': core.ImageSet.PIXELTYPE_UCHAR,
        'f': core.ImageSet.PIXELTYPE_FLOAT,
        'L': core.ImageSet.PIXELTYPE_ULONG,
    }[typeP]

    document = viewer.get_document()
    if document is None:
        viewer.create_document(im_name, pixel_type, bounds.width, bounds.height, planes, 1, 0)
        document = viewer.get_document()
    else:
        # Replace a result imageset left over from a previous run.
        existing = document.get_imageset(imageset_name)
        if existing is not None:
            print('Deleted imageset:', imageset_name)
            document.delete_imageset(existing)

    newImageset = document.create_imageset(imageset_name, pixel_type, 0)  # zero channels - all channels are created later
    newImageset.insert_timepoints(0, 1)  # a single timepoint at index 0
    newImageset.insert_planes(0, 0, planes, bounds)  # timepoint, first plane, number of planes, XY bounds
    print('created:', imageset_name)
    document.save()  # debug: saves after every tile
    return newImageset, imageset_name
#%% Function: readOneImageStack
def readOneImageStack(imageset, tile, channel, Plane_Number):
    """Read the planes of one tile of one channel into a new NumPy buffer.

    Tile *tile* covers planes ``tile*Plane_Number`` through
    ``tile*Plane_Number + Plane_Number - 1``. The full XY bounding box is read
    at timepoint ``TIMEPOINT``, and the buffer dtype comes from getPixelTypes().
    *channel* is zero-based.

    Returns a 2D array (rows, cols) when ``TWO_D`` is set. That presumably
    requires Plane_Number == 1 so the region and buffer sizes agree. Otherwise
    returns a 3D array (planes, rows, cols).
    """
    pixel_dtype = getPixelTypes(imageset)[0]
    box = imageset.get_bounding_box()
    first_plane = tile * Plane_Number
    last_plane = first_plane + Plane_Number - 1

    region = core.Bounds3D()
    region.x1, region.x2 = box.x1, box.x2
    region.y1, region.y2 = box.y1, box.y2
    region.z1, region.z2 = first_plane, last_plane

    rows = box.y2 - box.y1 + 1
    cols = box.x2 - box.x1 + 1
    shape = (rows, cols) if TWO_D else (last_plane - first_plane + 1, rows, cols)
    stack = np.empty(shape, pixel_dtype)

    imageset.read_imagedata(region, channel, TIMEPOINT, stack)
    return stack

#%% Function: transferAllChannels
def transferAllChannels(fixed_imageset, destination_imageset, tile_imageset_name, tile, Plane_Number, fixed_cycle=None):
    """Copy every channel of one tile of *fixed_imageset* into *destination_imageset*.

    Each channel is read with readOneImageStack and appended with
    insertResultImage under the name
    ``<channel name> + OUTPUT_CH_APPEND + <fixed_cycle>``.

    Parameters
    ----------
    fixed_imageset : arivis ImageSet
        Source (FIXED cycle) imageset.
    destination_imageset : arivis ImageSet
        Imageset that receives the copied channels.
    tile_imageset_name : str
        Unused; kept for call compatibility.
    tile : int
        Zero-based tile index.
    Plane_Number : int
        Planes per tile.
    fixed_cycle : int, optional
        Cycle number used in the channel names. Defaults to the module-level
        ``cycle_list[0]`` (the previously hard-wired value).
    """
    if fixed_cycle is None:
        fixed_cycle = cycle_list[0]
    # Take the cast type from the source imageset itself rather than from the
    # module-level ``typeP`` global.
    _, pixel_code, _ = getPixelTypes(fixed_imageset)

    for channel in range(fixed_imageset.get_channel_count()):
        imageStack = readOneImageStack(fixed_imageset, tile, channel, Plane_Number)
        ch_name = fixed_imageset.get_channel_name(channel) + OUTPUT_CH_APPEND + str(fixed_cycle)
        insertResultImage(imageStack, destination_imageset, ch_name, pixel_code)
           
#%% Function: findImageset
def findImageset(cycle_searched):
    """Return the open imageset whose name encodes cycle *cycle_searched*.

    All documents of all viewers are scanned. The cycle number is the text
    between the last '-' and the following '.' of the imageset name
    (e.g. ``'sample-03.czi'`` gives 3). Names are skipped when they have no
    '.' after the last '-', or when that text is not an integer. If several
    imagesets match, the last one encountered is returned.

    NOTE: result imagesets named ``'<fixed name>_tile_N'`` parse to the fixed
    cycle's number as well.

    Exits the script (``sys.exit``) when no imageset matches.
    """
    searched_imageset = None
    for viewer in arivis.App.get_viewer_list():
        document = viewer.get_document()
        if document is None:
            continue
        for imageset in document.get_imagesets():
            file_name_split = imageset.get_name().split('-')[-1].split('.')
            if len(file_name_split) <= 1:
                continue
            try:
                cycle_number = int(file_name_split[0])
            except ValueError:
                # Non-numeric suffix: not a cycle imageset.
                continue
            if cycle_number == cycle_searched:
                # Keep scanning: the last match wins.
                searched_imageset = imageset

    if searched_imageset is None:
        sys.exit('No open imageset found for cycle ' + str(cycle_searched) + '. Aborting the script.')
    return searched_imageset

#%% Function: checkPlaneNumber
def checkPlaneNumber(imageset, tile_number=None):
    """Return the number of planes per tile in *imageset*.

    Parameters
    ----------
    imageset : arivis ImageSet
        Imageset whose ``get_plane_count()`` is split into equal tiles.
    tile_number : int, optional
        Number of tiles. Defaults to the module-level ``TILE_NUMBER``.

    Exits the script (``sys.exit``) when the plane count is not divisible
    by the number of tiles.
    """
    if tile_number is None:
        tile_number = TILE_NUMBER

    imageset_name = imageset.get_name()
    plane_count = imageset.get_plane_count()
    if plane_count % tile_number != 0:
        print(imageset_name)
        print('The number of planes does not correspond to the number of TILES given')
        sys.exit('The number of planes does not correspond to the number of TILES given. Aborting the script.')

    # Integer division: exact for any plane count, unlike int(a / b).
    Plane_Number = plane_count // tile_number
    print(Plane_Number, ' planes in ', imageset_name)
    return Plane_Number

# %% Main
# Start the wall-clock timer that is reported at the end of the script.
startTime = time.time()
print("Script is running ........ ")

#%% Check the image dimensions and number
# Only relevant when raw files still have to be converted to SIS.
if RUN_SIS_IMPORT:
    path = ORIG_IMAGE_PATH + '*.*'
    # Lexicographic order of the full paths.
    image_list = sorted(glob(path))
    n_images = len(image_list)
    n_dapi = len(IMAGESET_DAPI_CHANNEL)
    print(n_images, ' (image) files found')
    if n_images <= 1:
        sys.exit('Error: Not enough images for registration')
    if n_images != n_dapi:
        print('Error: The number of the given DAPI channels does not correspond to the number of images given.')
        print('Will use the first given DAPI channels.')
        # More images than DAPI entries cannot be handled: abort.
        if n_images > n_dapi:
            sys.exit('Error: The number of the given DAPI channels does not correspond to the number of images given. Please, check the folder with the raw data.')
   
#%% Run the SIS conversion and import one Cycle for testing
if RUN_SIS_IMPORT:
    first_file = image_list[0]
    remaining_files = image_list[1:]
    # Import the first file on its own, then query the active viewer's imageset.
    # The values read here are not used further in this script.
    SisFileImport(first_file)
    test_viewer = arivis.App.get_active_viewer()
    testImageset = test_viewer.get_imageset()
    testPlaneCount = testImageset.get_plane_count()
    for file_name in remaining_files:
        SisFileImport(file_name)

#%%  Find the first cycle imageset
# Collect the cycle numbers encoded in the names of all open imagesets: the
# text between the last '-' and the following '.' (e.g. 'name-03.czi' gives 3).
cycle_list = []
viewer_list = arivis.App.get_viewer_list()
for viewer in viewer_list:
    document = viewer.get_document()
    if document is None:
        continue
    imageset_list = document.get_imagesets()
    for imageset in imageset_list:
        imageset_name = imageset.get_name()
        path_name_split = imageset_name.split('-')
        file_name_split = path_name_split[-1].split('.')
        cyc = file_name_split[0]
        # A non-numeric suffix (e.g. an existing result imageset) is reported
        # and skipped instead of crashing int(). The former check `cyc == str`
        # compared a string with the type `str` and could never be true.
        try:
            cycle_number = int(cyc)
        except ValueError:
            print('The Registered SIS file might alreayd exist. Please remove it from the SIS folder.')
            continue
        print('Cycle number detected: ', file_name_split[0])
        cycle_list.append(cycle_number)

if not cycle_list:
    sys.exit('No imageset with a numeric cycle number was found. Aborting the script.')

cycle_list.sort()
print('Cycles detected: ', len(cycle_list))
print('First cycle detected: ', cycle_list[0], '. This will be the FIXED image')

if len(cycle_list) > len(IMAGESET_DAPI_CHANNEL):
    sys.exit('The number of the given DAPI channels does not correspond to the number of images given. Aborting the script.')

#%% Check the number of planes in the fixed imageset
# The lowest cycle number is the FIXED reference for every registration.
first_cycle = cycle_list[0]
fixed_imageset = findImageset(first_cycle)
typeP = getPixelTypes(fixed_imageset)[1]  # type code used when writing all result channels
fixed_Plane_Number = checkPlaneNumber(fixed_imageset)
fixed_pixel_size = fixed_imageset.get_pixel_size()

#%% Open a 2D viewer; the viewer that is active right after opening receives the results
arivis.App.open_window('VIEWER_2D')
dest_viewer = arivis.App.get_active_viewer()

#%% Prepare the registration parameter object:
# RESOLUTIONS is the bspline-specific resolution count (see the user settings);
# every other method uses NUMBER_OF_RESOLUTIONS. Use the same value both to
# build the default map and for the explicit "NumberOfResolutions" override
# below, so the two cannot disagree.
if REGISTRATION_METHOD == 'bspline':
    resolution_count = RESOLUTIONS
else:
    resolution_count = NUMBER_OF_RESOLUTIONS

# Elastix parameter values are set as strings.
MaximumNumberOfIterations = str(MAX_NUMBER_OF_ITERATIONS)
MaximumNumberOfSamplingAttempts = str(MAX_NUMBER_OF_SAMPLING_ATTEMPTS)
NumberOfResolutions = str(resolution_count)

parameter_object = itk.ParameterObject.New()
if REGISTRATION_METHOD == 'bspline':
    # GRID is the final grid spacing argument, used for bspline only.
    parameter_map = parameter_object.GetDefaultParameterMap(REGISTRATION_METHOD, resolution_count, GRID)
else:
    parameter_map = parameter_object.GetDefaultParameterMap(REGISTRATION_METHOD, resolution_count)

parameter_object.AddParameterMap(parameter_map)
parameter_object.SetParameter("MaximumNumberOfIterations", MaximumNumberOfIterations)
parameter_object.SetParameter("MaximumNumberOfSamplingAttempts", MaximumNumberOfSamplingAttempts)
parameter_object.SetParameter("NumberOfResolutions", NumberOfResolutions)

#%% Prepare a new imageset for each tile for the registered imageset
# For every tile:
#   1. create a destination imageset;
#   2. copy all channels of the FIXED (first) cycle into it;
#   3. for each later cycle, register its DAPI stack onto the fixed DAPI stack,
#      then apply the resulting transform to that cycle's other channels.
for tile in range(0, TILE_NUMBER):
     print('Processing FIXED imageset tile number ', tile)
     destination_imageset, tile_imageset_name = prepareRegisteredImageset(fixed_imageset, tile, dest_viewer, fixed_Plane_Number)
     # Unregistered copy of every fixed-cycle channel.
     transferAllChannels(fixed_imageset, destination_imageset, tile_imageset_name, tile, fixed_Plane_Number)
     # Fixed DAPI stack of this tile: the registration target for all moving cycles.
     fixed_image = readOneImageStack(fixed_imageset, tile, IMAGESET_DAPI_CHANNEL[0]-1, fixed_Plane_Number) #channel count starts from zero

     # for each moving cycle
     for cycle in range(1, len(cycle_list)):
        print('Processing moving imageset, cycle ', cycle_list[cycle])
        # These lookups do not depend on `tile` and are repeated for every tile.
        moving_imageset = findImageset(cycle_list[cycle])
        moving_Plane_Number = checkPlaneNumber(moving_imageset)
        moving_pixel_size = moving_imageset.get_pixel_size()
        # NOTE(review): index 0 of get_pixel_size() is treated as the Z size
        # (registerImages names this ratio Z_spacing_ratio). Confirm the axis
        # order that arivis returns.
        Z_spacing_ratio = fixed_pixel_size[0] / moving_pixel_size[0]

        moving_image = readOneImageStack(moving_imageset, tile, IMAGESET_DAPI_CHANNEL[cycle]-1, moving_Plane_Number) #channel count starts from zero
        moving_dapi_name = moving_imageset.get_channel_name(IMAGESET_DAPI_CHANNEL[cycle]-1) #channel count starts from zero
        ch_name = moving_dapi_name + OUTPUT_CH_APPEND + str(cycle_list[cycle]) + DAPI
        if KEEP_DAPI_CHANNEL:
            result_image, result_transform_parameters = registerImages(fixed_image, moving_image, parameter_object, Z_spacing_ratio)
            # Clamp negative values to 0 before insertResultImage casts to the pixel type.
            result_image[result_image<0]=0

            insertResultImage(result_image, destination_imageset, ch_name, typeP)
        else:
            print('Passing the DAPI channel', ch_name)
            # Registration still runs: its transform is needed for the staining channels below.
            _, result_transform_parameters = registerImages(fixed_image, moving_image, parameter_object, Z_spacing_ratio)

        # Apply the DAPI-derived transform to every non-DAPI channel of this cycle.
        channel_number = moving_imageset.get_channel_count()
        for channel in range(0, channel_number):
            if channel == IMAGESET_DAPI_CHANNEL[cycle]-1: #channel count starts from zero
                # DAPI was handled above.
                pass
            else:
                print('Processing the staining channel', channel)
                trans_image = readOneImageStack(moving_imageset, tile, channel, moving_Plane_Number)
                trans_image = trans_image.astype(np.float32)
                # Z_spacing_ratio
                # NOTE(review): the plain NumPy array is passed, presumably with
                # default (unit) spacing, matching the moving-image spacing set in
                # registerImages. Confirm that transformix sees the same spacing.
                exp_image_regged = itk.transformix_filter(trans_image, result_transform_parameters, log_to_console = False)
                exp_image_regged = np.asarray(exp_image_regged).astype(np.float32)

                exp_image_regged[exp_image_regged<0]=0
                # Despite its name, this holds the current staining channel's name here.
                moving_dapi_name = moving_imageset.get_channel_name(channel)
                ch_name = moving_dapi_name + OUTPUT_CH_APPEND + str(cycle_list[cycle])

                insertResultImage(exp_image_regged, destination_imageset, ch_name, typeP)

#%% remove the empty default imageset
# Only an unnamed default imageset is deleted; a named one is kept.
dest_document = dest_viewer.get_document()
defaultImageset = dest_document.get_default_imageset()
if defaultImageset.get_name() == '':
    dest_document.delete_imageset(defaultImageset)

#%% Report the total run time in seconds
endTime = time.time()
elapsed = endTime - startTime
print("script time: " + str(elapsed))







