# coding: utf-8
#
# NAME: CellPose Segmenter
# FILE: CellPose_Segmenter_RevA                           
# REVISION : 1.4.1 - 2022-05-14
# AUTHOR : Maurizio Abbate / Mariia Burdyniuk
#          Giovanni Cardone (Biochem-MPG)
# Copyright(c) 2021 arivis AG, Germany. All Rights Reserved.
#
# Permission is granted to use, modify and distribute this code,
# as long as this copyright notice remains part of the code.
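# PURPOSE : Segment cells/nuclei with Cellpose and write the resulting label
#           masks back to arivis as segments.
# Tested for V4d Release : 3.6
# NOTE: Tiling has been added to handle huge datasets.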
# ------------------------------ External Package Import ----------------------
import arivis_parameter, arivis_operation
import arivis_core as core
import numpy as np
from cellpose import models
# import cv2
# from arivis_objects import Polygon
# from arivis_core import Point2D
# import arivis_objects as objects


@arivis_parameter.add_arivis_parameters(Model_Name='cyto2',
                                        Input_channel=1,
                                        Second_channel=0,
                                        Diameter_in_um=3.,
                                        Flow_threshold=arivis_parameter.param_range(0.4, 0.01, 1),
                                        Cellprob_threshold=arivis_parameter.param_range(0, -6, 6),
                                        Fast_3D=True,
                                        Custom_model=False,
                                        Custom_model_path='C:/',
                                        Tile_Size=arivis_parameter.param_range(4096, 1024, 65536))
@arivis_parameter.add_param_description(Model_Name='Choose from {cyto, cyto2, nuclei}',
                                        Input_channel="Main prediction channel",
                                        Second_channel='Additional channel for the prediction (e.g. nuclei for Cyto2 model)',
                                        Diameter_in_um='Estimated diameter in microns (0:auto estimation)',
                                        Flow_threshold='Consistency with model (0...1)',
                                        Cellprob_threshold='Mask inclusion threshold (-6...+6)',
                                        Fast_3D='Analyse slices separately',
                                        Custom_model='Do you have your own model?',
                                        Custom_model_path='Model path including the file name',
                                        Tile_Size="Max tile size (XY - pixels)")
def main(Model_Name, Input_channel, Second_channel, Diameter_in_um, Flow_threshold, Cellprob_threshold, Fast_3D, Custom_model, Custom_model_path, Tile_Size):
    """Run Cellpose on the operation's input and emit the masks as arivis segments.

    The input volume is processed per timepoint and, when larger than
    ``Tile_Size`` in X or Y, per XY tile. Each tile's label mask is handed to
    ``SegmentBuilder.create_segments``.

    Parameters (as exposed in the arivis UI):
        Model_Name: built-in Cellpose model name, used when ``Custom_model`` is False.
        Input_channel: 1-based index of the main prediction channel.
        Second_channel: 1-based index of an optional second channel; 0 = none.
        Diameter_in_um: expected object diameter in microns; a value that
            converts to <= 0 pixels lets Cellpose estimate it (diameter=None).
        Flow_threshold, Cellprob_threshold: forwarded to ``model.eval``.
        Fast_3D: for Z stacks, segment slices in 2D and stitch them instead of
            running full 3D (``do_3D``).
        Custom_model: load ``Custom_model_path`` via ``models.CellposeModel``.
        Custom_model_path: path of the custom model file.
        Tile_Size: maximum tile edge in pixels (XY).

    Prints an error and returns early when the channel configuration is invalid.
    """
    import os
    import warnings

    # Make tensorflow output silent.
    # NOTE(review): cellpose is imported at module load time, so this env var
    # may be set too late to affect an already-imported backend; verify.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # or any {'0', '1', '2'}
    # Suppress syntax warnings emitted by third-party packages.
    warnings.filterwarnings("ignore", category=SyntaxWarning)
    np.random.seed(6)

    # Decide ONCE, on the user's 1-based value, whether a second channel is used,
    # so the Cellpose `channels` argument and the buffer layout always agree.
    use_second_channel = Second_channel >= 1
    # Cellpose channels are 1-based into the buffer's last axis; 0 = "none".
    channels_CP = [1, 2] if use_second_channel else [1, 0]
    input_index = Input_channel - 1     # zero-based for read_imagedata
    second_index = Second_channel - 1

    context = arivis_operation.Operation.get_context()
    input_data = context.get_input()
    output_data = context.get_output()
    input_bounds = input_data.get_bounds()  # bounding box 4D
    pixel_type = input_data.get_pixel_type()
    pixel_size = input_data.get_pixel_size()

    # ---------------------------------------------------------------------------
    # validate the channel configuration before loading the (expensive) model
    # ---------------------------------------------------------------------------
    number_of_channels = input_data.get_channel_count()
    if number_of_channels < 1:
        print("Error: this script needs at least one channel. End of script.")
        return
    if not 0 <= input_index < number_of_channels:
        print("Error: Input_channel is out of range. End of script.")
        return
    if use_second_channel and second_index >= number_of_channels:
        print("Error: Second_channel is out of range. End of script.")
        return

    # configure model
    if not Custom_model:
        model = models.Cellpose(gpu=True, model_type=Model_Name)
    else:
        model = models.CellposeModel(gpu=True, pretrained_model=Custom_model_path, net_avg=False)

    # ---------------------------------------------------------------------------
    # diameter: anything that rounds to <= 0 pixels means "auto estimation"
    # ---------------------------------------------------------------------------
    diameter_in_pixels = int(Diameter_in_um / pixel_size.x)
    diameter = diameter_in_pixels if diameter_in_pixels > 0 else None

    number_of_planes = input_bounds.z2 - input_bounds.z1 + 1
    # ---------------------------------------------------------------------------
    # 3D handling and anisotropy
    # ---------------------------------------------------------------------------
    if number_of_planes > 1:
        process_3D = not Fast_3D
        stitch_thr = 0.2
        anisotropy = pixel_size.z / pixel_size.x   # assuming same x and y
    else:
        process_3D = False
        stitch_thr = 0
        anisotropy = 1.0

    # ---------------------------------------------------------------------------
    # tiles: whole XYZ region unless it exceeds Tile_Size in X or Y
    # ---------------------------------------------------------------------------
    whole = core.Bounds3D()
    whole.x1 = input_bounds.x1
    whole.x2 = input_bounds.x2
    whole.y1 = input_bounds.y1
    whole.y2 = input_bounds.y2
    whole.z1 = input_bounds.z1
    whole.z2 = input_bounds.z2
    if Tile_Size < max(whole.x2 - whole.x1 + 1, whole.y2 - whole.y1 + 1):
        tiles = list(core.get_tiles(input_bounds, Tile_Size, overlap=0))
    else:
        tiles = [whole]

    # Progress is reported once per (timepoint, tile), so count both.
    total_work = (input_bounds.t2 - input_bounds.t1 + 1) * len(tiles)
    work_done = 0

    # Arguments shared by both Cellpose model flavours; loop-invariant.
    eval_kwargs = dict(diameter=diameter, channels=channels_CP,
                       flow_threshold=Flow_threshold, cellprob_threshold=Cellprob_threshold,
                       stitch_threshold=stitch_thr, anisotropy=anisotropy, do_3D=process_3D)

    # ---------------------------------------------------------------------------
    # run tiles
    # ---------------------------------------------------------------------------
    for timepoint in range(input_bounds.t1, input_bounds.t2 + 1):
        for tile in tiles:
            if use_second_channel:
                # Stack both channels along a trailing axis: (z, y, x, 2).
                buffer = np.empty((tile.z2 - tile.z1 + 1, tile.y2 - tile.y1 + 1, tile.x2 - tile.x1 + 1, 2), dtype=pixel_type)
                buffer[..., 0] = input_data.read_imagedata(region=tile, channel=input_index, timepoint=timepoint)
                buffer[..., 1] = input_data.read_imagedata(region=tile, channel=second_index, timepoint=timepoint)
            else:
                buffer = input_data.read_imagedata(region=tile, channel=input_index, timepoint=timepoint)

            # Cellpose.eval returns (masks, flows, styles, diams) and
            # CellposeModel.eval returns (masks, flows, styles); only masks is used.
            masks = model.eval(buffer, **eval_kwargs)[0]

            # -------------------------------------------------------------------
            # prepare segmentBuilder
            # -------------------------------------------------------------------
            segment_builder = arivis_operation.SegmentBuilder.segment_builder(input_data, output_data)
            with segment_builder.at_timepoint(timepoint):
                print('tile_bounds is ' + str(tile.__dict__))
                segment_builder.create_segments(masks, tile)
                work_done += 1
                context.notify_progress(work_done * 100 / total_work)
      
      
