
# coding: utf-8
#
# NAME: YOLO_detecet_operator_v4
# FILE: YOLO_detecet_operator_v4.py                          
# REVISION : 2023-10-30
# AUTHOR : Kenneth Gao (Altos labs)/ Mariia Burdyniuk
 
# Copyright(c) 2023 arivis AG, Germany. All Rights Reserved.
#
# Permission is granted to use, modify and distribute this code,
# as long as this copyright notice remains part of the code.

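# PURPOSE: split each object carrying the given tag into box-shaped segments along its
# centerline. The centerline is traced on the object's middle z-plane and every box is
# replicated on all z-planes of the object; time points are not handled separately.
# NOTE: despite the file name and the earlier description (running inference with the
# 2D YOLO model to create bounding-box objects), this version does not run a YOLO model.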
    
# Tested for arivis Pro Release : 4.1.1
# NOTE: Tiling has been added to address large datasets 

#%%
import numpy as np
import pandas as pd

import arivis_parameter, arivis_operation
import arivis_core as core
import arivis_objects as objects
#import math

import cv2
import matplotlib.pyplot as plt
from skimage.morphology import skeletonize, thin, dilation, skeletonize_3d, medial_axis
from skan import Skeleton, draw
#%%
@arivis_parameter.add_arivis_parameters(Input_Tag = "[ENTER_OBJECT_TAG_HERE]", Input_channel = 1, Width_in_pixels_vertical = 200, Width_in_pixels_horizontal = 100)
@arivis_parameter.add_param_description(Input_Tag = "Objects with specified tag are processed.", \
                                        Input_channel = "This channel is for the prediction",\
                                        Width_in_pixels_vertical = "Vertical Width of the segment boxes",\
                                        Width_in_pixels_horizontal = "Horizontal Width of the segment boxes")
def main(Input_Tag, Input_channel, Width_in_pixels_vertical, Width_in_pixels_horizontal):
    """Cut every tagged object into quadrilateral segments along its centerline.

    For each object carrying ``Input_Tag``:

    1. Read the mask of the object's middle z-plane.
    2. Dilate, thin and skeletonize that mask.
    3. Sample the first skan skeleton path every ``Width_in_pixels_vertical``
       path pixels.
    4. At each sample, compute a line normal to the path that reaches
       ``Width_in_pixels_horizontal`` pixels to either side.

    Every pair of consecutive normal lines becomes one quadrilateral segment.
    That segment is replicated on all z-planes of the object and added to the
    output.

    Args:
        Input_Tag: tag selecting the objects to process.
        Input_channel: 1-based channel index shown in the operator UI. It is not
            used by the current implementation and is kept so the operator's
            parameter list stays unchanged.
        Width_in_pixels_vertical: spacing between consecutive normals, counted
            in skeleton-path pixels; must be positive.
        Width_in_pixels_horizontal: half-length of each normal line, in pixels.

    Raises:
        ValueError: if ``Width_in_pixels_vertical`` is not positive.
    """
    if Width_in_pixels_vertical <= 0:
        raise ValueError("Width_in_pixels_vertical must be a positive integer")

    context = arivis_operation.Operation.get_context()
    input_data = context.get_input()
    output_data = context.get_output()
    IdList = input_data.get_object_ids(Input_Tag)
    total_objects = len(IdList)
    print(str(total_objects), ' objects with the given tag found. Start processing...')

    for done, object_id in enumerate(IdList, start=1):
        Object1 = input_data.get_object(object_id, True)
        objBounds = Object1.get_bounds()

        # XY footprint of the object; masks are read cropped to this box.
        obj_bounds_2D = core.Bounds2D()
        obj_bounds_2D.x = objBounds.x1
        obj_bounds_2D.y = objBounds.y1
        obj_bounds_2D.height = objBounds.y2 - objBounds.y1 + 1
        obj_bounds_2D.width = objBounds.x2 - objBounds.x1 + 1

        # Only the middle plane drives the centerline, so only that plane is read.
        mid_plane = objBounds.z1 + (objBounds.z2 - objBounds.z1) // 2
        mask = _read_plane_mask(Object1, mid_plane, obj_bounds_2D)
        path_coordinates = _centerline_path(mask)

        if path_coordinates is None:
            print('Object', object_id, 'has no centerline on its middle plane; skipped.')
        else:
            normals = _normal_lines(path_coordinates,
                                    Width_in_pixels_vertical,
                                    Width_in_pixels_horizontal)
            # Normal endpoints are (x, y) relative to the crop; shift to image coordinates.
            offset_x, offset_y = obj_bounds_2D.x, obj_bounds_2D.y
            for (pt1, pt2), (pt3, pt4) in zip(normals, normals[1:]):
                p1, p2, p3, p4 = [core.Point2D(int(pt[0] + offset_x), int(pt[1] + offset_y))
                                  for pt in (pt1, pt2, pt3, pt4)]
                segmented_obj = createSegment(p1, p2, p3, p4, objBounds.z1, objBounds.z2 + 1)
                output_data.add_object(segmented_obj)

        # Progress is reported per processed object, so it ends at exactly 100.
        context.notify_progress(done * 100 / total_objects)


# Return the boolean mask of `obj` on z-plane `plane`, cropped to `bounds_2d`.
def _read_plane_mask(obj, plane, bounds_2d):
    mask = np.zeros((bounds_2d.height, bounds_2d.width), dtype=bool)
    obj.get_mask(mask, plane, bounds_2d, 1)
    return mask


# Return the (row, col) coordinates of the first skeleton path of `mask`, or None if empty.
def _centerline_path(mask):
    edges = mask.astype(np.uint8) * 255
    kernel = np.ones((7, 7), np.uint8)
    # Heavy dilation before thinning, presumably to close gaps and smooth the
    # outline so a single centerline is produced.
    dilated = cv2.dilate(edges, kernel, iterations=5)
    thinned = thin(dilated, max_num_iter=150)
    centerline = skeletonize(thinned)
    if not centerline.any():
        return None
    skel = Skeleton(centerline)
    if skel.n_paths == 0:
        return None
    return skel.coordinates[skel.path(0)]


# Return [(start, end), ...] normal-line endpoints (x, y) sampled every `step` path pixels.
def _normal_lines(path_coordinates, step, half_width):
    normals = []
    count = len(path_coordinates)
    for i in range(0, count, step):
        row, col = path_coordinates[i][0], path_coordinates[i][1]
        if i + step < count:
            ahead = path_coordinates[i + step]
            toward = (ahead[1], ahead[0])
        elif i >= step:
            # Last sample: extrapolate the tangent from the previous sample so the
            # normal keeps the same orientation as the others; otherwise the final
            # quadrilateral can self-intersect.
            behind = path_coordinates[i - step]
            toward = (2 * col - behind[1], 2 * row - behind[0])
        else:
            # Path shorter than one step: a single normal, which yields no segment.
            toward = (col + step, row)
        # draw_normal_line does not draw, so no image is needed.
        normals.append(draw_normal_line(None, (col, row), toward, length=half_width))
    return normals
#%%
# Build a Segment whose quadrilateral outline is repeated on a range of planes.
def createSegment(p1, p2, p3, p4, plane_begin, plane_end):
    """Return an ``objects.Segment`` with the polygon p1-p2-p4-p3 on each plane.

    Args:
        p1, p2: endpoints of one edge of the quadrilateral.
        p3, p4: endpoints of the opposite edge. Presumably listed in the same
            direction as (p1, p2), as ``main`` passes two consecutive normal
            lines; that is why the outline is walked p1 -> p2 -> p4 -> p3.
        plane_begin: first z-plane (inclusive) that receives the polygon.
        plane_end: end z-plane (exclusive).

    Returns:
        The new segment. No time point is set on it.
    """
    segment = objects.Segment()

    outline = objects.Polygon()
    outline.set_contour([p1, p2, p4, p3])

    z = plane_begin
    while z < plane_end:
        segment.add_polygon(outline, z)
        z += 1
    return segment


def draw_normal_line(img, line_start, line_end, length, color=(0, 0, 255), thickness=2):
    """Return the endpoints of a line normal to ``line_start -> line_end``.

    The normal line passes through ``line_start`` (not the midpoint of the
    input line) and extends ``length`` pixels to either side of it.

    With direction ``d = line_end - line_start`` and unit normal
    ``n = (-d[1], d[0]) / |d|``, the result is
    ``(round(line_start + length * n), round(line_start - length * n))``,
    each returned as a tuple of ints.

    Args:
        img: unused. The drawing call is disabled; the argument is kept for
            backward compatibility.
        line_start: (x, y) point the normal passes through.
        line_end: (x, y) point defining the direction of the original line.
        length: half-length of the normal line.
        color: unused, kept for backward compatibility.
        thickness: unused, kept for backward compatibility.

    Returns:
        ``(normal_line_start, normal_line_end)``, two integer (x, y) tuples.

    Raises:
        ValueError: if ``line_start`` equals ``line_end`` (no direction defined).
    """
    start = np.asarray(line_start, dtype=float)
    direction = np.asarray(line_end, dtype=float) - start
    norm = np.linalg.norm(direction)
    if norm == 0:
        raise ValueError("line_start and line_end must differ to define a normal direction")

    # Rotate the direction by +90 degrees and scale it to unit length.
    normal = np.array([-direction[1], direction[0]]) / norm

    normal_line_start = tuple(np.round(start + length * normal).astype(int))
    normal_line_end = tuple(np.round(start - length * normal).astype(int))
    return normal_line_start, normal_line_end

