#!/usr/bin/env python
from __future__ import print_function

from gazebo_msgs.msg import ModelState, ModelStates
from sensor_msgs.msg import LaserScan, JointState
from race.msg import drive_param
from std_msgs.msg import Float32
from geometry_msgs.msg import Pose

import roslib
import rospy

from scipy.spatial.transform import Rotation
from gym.utils import seeding
from copy import deepcopy
from gym import spaces
import numpy as np
import random
import threading
import time
import gym
import sys

class Env(gym.Env):
  """Gym environment that drives the Gazebo 'racecar' model over ROS topics.

  Observation (float32, length ``len(angle_range) + 2``): averaged lidar bins
  divided by 10, the latest ``/ang_vel_data`` value divided by 100, and the
  last commanded steering value. ``step(action, vel)`` takes the steering
  command and the velocity command as separate arguments.
  """

  # Spawn poses per track as (x, y, yaw). z is always _SPAWN_Z, and yaw is a
  # rotation about +z converted with Rotation.from_rotvec([0, 0, yaw]).
  # Note that +pi and -pi give quaternions of opposite sign; both are kept as
  # originally specified.
  _SPAWN_Z = 0.01
  _NUM_SPAWNS = 20
  _SPAWN_POSES = {
    '2': [
      (0.0, 0.0, np.pi/2),
      (0.0, 4.0, np.pi/2),
      (0.0, 3.0, np.pi/2),
      (-3.5, 8.0, 0),
      (-8.0, 6.0, 0),
      (-6.0, 6.0, np.pi),
      (-10.0, 8.0, np.pi),
      (-13.0, 7.0, -np.pi/4*3),
      (-15.0, 6.0, -np.pi/4*3),
      (-15.0, 3.0, np.pi/2),
      (-15.0, -3.0, np.pi/2),
      (-14.0, -5.0, -np.pi/2),
      (-14.0, -8.0, -np.pi/4),
      (-12.0, -9.5, np.pi/2),
      (-10.0, -11.0, -np.pi/4),
      (-10.0, -11.0, -np.pi/4),
      (-14.0, -5.0, -np.pi/2),
      (-6.0, -12.5, np.pi),
      (-6.0, -12.5, -np.pi),
      (-3.0, -12.5, 0),
    ],
    '4': [
      (-1.5, 5.0, np.pi/2),
      (-1.5, 8.0, -np.pi/2),
      (-1.5, -4.0, -np.pi/2),
      (-1.5, -8.0, -np.pi),
      (-8, -8.0, 0),
      (-16.0, -8.0, np.pi),
      (-20.0, -8.0, np.pi/2),
      (-23.0, 8.0, -np.pi/2),
      (-23.0, 8.0, np.pi/2),
      (-23.0, 16.0, 0),
      (-17.0, 16.0, np.pi),
      (-17.0, 16.0, 0),
      (-8.0, 16.0, 0),
      (-8.0, 16.0, np.pi),
      (-8, -8.0, 0),
      (-17.0, 16.0, -np.pi),
      (-23.0, 8.0, np.pi/2),
      (-23.0, 8.0, -np.pi/2),
      (-16.0, -8.0, np.pi),
      (-8, -8.0, 0),
    ],
    '5': [
      (0.0, 0.0, np.pi/2),
      (0.0, 4.0, np.pi/2),
      (0.0, 3.0, np.pi/2),
      (-2.5, 8.0, np.pi),
      (-8.0, 6.0, 0),
      (-6.0, 6.0, np.pi),
      (-10.0, 8.0, np.pi),
      (-13.0, 7.0, -np.pi/4*3),
      (-15.0, 6.0, -np.pi/4*3),
      (-15.0, 3.0, np.pi/2),
      (-15.0, -3.0, np.pi/2),
      (-16.5, -5.0, -np.pi/2),
      (-15.5, -8.0, np.pi/2),
      (-13.0, -9.5, np.pi),
      (-10.0, -11.0, -np.pi/4),
      (-11.0, -11.0, -np.pi/4),
      (-15.0, -6.0, -np.pi/2),
      (-6.0, -12.5, np.pi),
      (-6.0, -12.5, -np.pi),
      (-3.0, -12.5, 0),
    ],
  }

  def __init__(self):
    rospy.init_node('get_state', anonymous=True)

    self.drive_msg = drive_param()
    self.drive_msg.angle = 0.0
    self.drive_msg.velocity = 0.0
    # Centre indices of the lidar bins; each bin averages ranges[idx-40:idx+40]
    # (per the scan note below, 40 samples ~ 10 degrees).
    self.angle_range = np.arange(40, 1080, 40)
    self.rate = rospy.Rate(10)  # unit: Hz
    self.limit_distance = 0.6  # closest bin below this ends the episode
    self.max_step = 2000
    self.cur_step = 0
    self.t_last = time.time()
    self.track = '2'
    self.idx = 0  # number of reset() calls so far

    # State buffers, written by the subscriber callbacks / drive().
    self.sensor_value = np.zeros_like(self.angle_range, dtype=np.float32)
    self.rpm_data = 0.0
    self.steering = 0.0
    # Initialise before pos_sub exists so the first pos_callback isn't overwritten.
    self.cur_pos = Pose()

    # State & action dimensions.
    self.num_action = 6  # discrete variant: left, straight, right + stop, go (not used by the Box space)
    self.state_dim = len(self.sensor_value) + 2  # lidar + velocity + steering
    self.action_dim = 2
    self.action_space = spaces.Box(-np.ones(self.action_dim), np.ones(self.action_dim), dtype=np.float32)
    self.observation_space = spaces.Box(-np.inf*np.ones(self.state_dim), np.inf*np.ones(self.state_dim), dtype=np.float32)

    # Publishers and subscribers.
    self.drive_pub = rospy.Publisher('/drive_parameters', drive_param, queue_size=1)
    self.init_pub = rospy.Publisher('/gazebo/set_model_state', ModelState, queue_size=1)
    self.state_sub = rospy.Subscriber('/scan', LaserScan, self.state_callback)  # -135 deg ~ 135 deg, length : 1081, 4 per 1 deg.
    self.rpm_sub = rospy.Subscriber('/ang_vel_data', Float32, self.rpm_callback)
    self.pos_sub = rospy.Subscriber('/gazebo/model_states', ModelStates, self.pos_callback)

    # Background thread that republishes self.drive_msg every 10 ms.
    self.drive_thread = threading.Thread(target=self.drive_pub_thread, args=())
    self.drive_thread.daemon = True
    self.drive_thread_loop = True
    self.drive_thread_flag = True
    self.drive_thread.start()

    init_quat = Rotation.from_rotvec([-0.5, -0.5, 0.01]).as_quat()
    init_pos = np.array([5.0, 5.0, 0.01])

    self.pos, self.quat = self._build_spawn_table(self.track)

    # Seed init_state from the current racecar entry, then override its pose.
    model_states = rospy.wait_for_message('gazebo/model_states', ModelStates)
    self.init_state = ModelState()
    idx = model_states.name.index('racecar')
    self.init_state.model_name = model_states.name[idx]
    self.init_state.pose = model_states.pose[idx]
    self.init_state.twist = model_states.twist[idx]
    self.init_state.reference_frame = ''
    self._set_init_pose(init_pos, init_quat)

  def _build_spawn_table(self, track):
    """Return (pos, quat) arrays of shape (20, 3) and (20, 4) for ``track``.

    Rows remain all-zero for tracks that have no entry in _SPAWN_POSES.
    """
    pos = np.zeros((self._NUM_SPAWNS, 3))
    quat = np.zeros((self._NUM_SPAWNS, 4))
    for i, (x, y, yaw) in enumerate(self._SPAWN_POSES.get(track, ())):
      pos[i] = np.array([x, y, self._SPAWN_Z])
      quat[i] = Rotation.from_rotvec([0, 0, yaw]).as_quat()
    return pos, quat

  def _set_init_pose(self, position, quaternion):
    """Copy (x, y, z) ``position`` and (x, y, z, w) ``quaternion`` into init_state."""
    pose = self.init_state.pose
    pose.position.x = position[0]
    pose.position.y = position[1]
    pose.position.z = position[2]
    pose.orientation.x = quaternion[0]
    pose.orientation.y = quaternion[1]
    pose.orientation.z = quaternion[2]
    pose.orientation.w = quaternion[3]

  def state_callback(self, data):
    """Average each lidar window ranges[idx-40:idx+40] into sensor_value, clipped to [0, 10]."""
    for i, idx in enumerate(self.angle_range):
      self.sensor_value[i] = np.clip(np.mean(data.ranges[idx-40:idx+40]), 0.0, 10.0)

  def rpm_callback(self, data):
    """Store the latest /ang_vel_data value."""
    self.rpm_data = data.data

  def pos_callback(self, data):
    """Store the current pose of the 'racecar' model."""
    self.cur_pos = data.pose[data.name.index('racecar')]

  def drive_pub_thread(self):
    """Republish self.drive_msg every 10 ms until drive_thread_loop is cleared."""
    try:
      while self.drive_thread_loop:
        self.drive_pub.publish(self.drive_msg)
        self.drive_thread_flag = True
        time.sleep(0.01)
    finally:
      # Clear the flag even if publish() raised, so close() never waits forever.
      self.drive_thread_flag = False

  def reset(self):
    """Stop the car, teleport it to spawn pose 0 and return the new state."""
    self.cur_step = 0
    self.rpm_data = 0.0
    self.steering = 0.0
    self.drive_msg.angle = 0.0
    self.drive_msg.velocity = 0.0
    time.sleep(1)
    self.rate.sleep()

    # Spawn index is fixed at 0 (random.randint(0, 19) would sample any row).
    idx = 0

    self.idx += 1
    print('\nreset environment')

    self._set_init_pose(self.pos[idx], self.quat[idx])
    self.init_pub.publish(self.init_state)
    time.sleep(1)
    self.rate.sleep()

    return self.get_state()

  def drive(self, vel, steer):
    """Set the command that the drive thread publishes."""
    self.steering = steer
    self.drive_msg.angle = steer
    self.drive_msg.velocity = vel

  def sigmoid(self, x):
    """Return 2 * logistic(x), i.e. a value in (0, 2)."""
    return 2 * (1 / (1+np.exp(-x)))

  def step(self, action, vel):
    """Apply one control command and return (state, reward, done, over).

    action: steering command, multiplied by steer_scale (1.0).
    vel: velocity command, multiplied by vel_scale (100.0) and clipped to >= 0.
    reward: the raw ``vel`` plus 0.2 * closest lidar bin, minus 10 when the
      closest bin is below limit_distance.
    done: collision or max_step reached. over: True only on max_step.
    """
    self.cur_step += 1

    steer_scale = 1.0
    vel_scale = 100.0
    steering = action * steer_scale
    velocity = np.clip(vel * vel_scale, 0.0, np.inf)

    self.drive(velocity, steering)
    self.rate.sleep()

    state = self.get_state()
    # Read once: the scan callback may rewrite sensor_value concurrently, and the
    # reward term and the collision check must see the same value.
    min_dist = np.min(self.sensor_value)
    reward = vel + min_dist * 0.2
    done = False
    over = False
    if min_dist < self.limit_distance:
      done = True
      reward -= 10.0

    if self.cur_step >= self.max_step:
      print('done')
      done = True
      over = True

    if done:
      self.drive_msg.angle = 0.0
      self.drive_msg.velocity = 0.0

    return state, reward, done, over

  def get_state(self):
    """Return the observation as float32, matching observation_space's dtype."""
    state = np.concatenate([self.sensor_value/10.0, [self.rpm_data/100.0], [self.steering]])
    return state.astype(np.float32)

  def get_pose(self):
    """Return the most recent racecar pose received on /gazebo/model_states."""
    return self.cur_pos

  def seed(self, seed=None):
    self.np_random, seed = seeding.np_random(seed)
    return [seed]

  def render(self, mode='human'):
    pass

  def close(self):
    """Stop the drive thread and wait for it to exit."""
    self.drive_thread_loop = False
    # join() already waits for the loop to finish; no need to poll drive_thread_flag.
    self.drive_thread.join()


if __name__ == '__main__':
  # Smoke run: three episodes with random (steering, velocity) commands.
  env = Env()
  try:
    for i in range(3):
      print('{} episode start!'.format(i+1))
      s_t = env.reset()

      while True:
        # Env.step takes steering and velocity separately. Sample both from the
        # declared [-1, 1]^2 action space; step() clips velocity to be >= 0.
        steer, vel = env.action_space.sample()
        s_t, r_t, done, over = env.step(steer, vel)
        if done:
          break
      print(over)
  finally:
    # Always stop the drive-publisher thread, even if an episode raised.
    env.close()

