https://www.acmicpc.net/problem/10473

 

10473번: 인간 대포

입력은 한 개의 길찾기 문제를 표현한다. 첫 줄에는 두 개의 실수가 입력되며 각각은 당신이 현재 위치한 X, Y좌표이다. 두 번째 줄에는 목적지의 X, Y좌표가 실수로 입력된다. 이어지는 줄에는 대

www.acmicpc.net


풀이 과정


  1. 시작점, 대포 위치, 도착점을 노드로 잡고 걸리는 시간을 행렬로 구해준다.
    • 시작점 -> 0, 대포들 -> (1 ~ N) 도착점 -> N+1
    • [i][j] => 노드 i에서 노드 j까지의 최소 시간
    • 시작점부터 각 노드의 걸리는 시간을 구하는 경우, 대포를 사용할수 없으며 달려가는 케이스밖에 없으므로 [0][j]는 시작점 ~ 도착점 거리를 구한 다음 5로 나누어주면 된다.
    • 이외의 경우에서는 대포를 사용할 수 있으므로
      1. 달려가는 케이스(거리 / 5)
      2. 거리가 50 초과일 때 대포 + 앞으로 달려가기 (2 + (거리-50)/5)
      3. 거리가 50 미만일 때 대포 + 뒤로 달려가기 (2 + (50-거리)/5)
      4. 거리가 50일 때 2
    • 위 네가지 케이스들을 맞추어서 최솟값을 구해서 행렬을 채워주면 된다.
  2. 시작점을 0으로 두고, 다익스트라 알고리즘을 사용한다.
    • 현재 최소 거리인 노드를 잡고, 이동 가능한 모든 노드를 이동해 보면서, 거리를 갱신하는 경우 거리 갱신 및 최소 히프에 거리 저장
    • 따라서, 최소 히프를 사용하면 좀 더 빠르게 구현 가능하다.
  3. N+1 노드까지의 최소 거리를 출력한다.

일반 다익스트라 알고리즘보다 어려운거 같은데 정답률은 높음.. 뭐지..?

 


소스 코드


import sys
import heapq
import math

input = lambda : sys.stdin.readline().rstrip()

def calculate_distance(src, dest):
    return math.sqrt((src[0]-dest[0]) ** 2 + (src[1]-dest[1]) ** 2)

myX, myY = map(float, input().split())
targetX, targetY = map(float, input().split())
N = int(input())

points = [list(map(float, input().split())) for _ in range(N)]
points = [[myX, myY]] + points + [[targetX, targetY]]

time_matrix = [[0] * (N+2) for _ in range(N+2)]

for i in range(len(points)):
    for j in range(i+1, len(points)):
        if i == 0:
            time_matrix[i][j] = calculate_distance(points[i], points[j]) / 5
        else:
            distance = calculate_distance(points[i], points[j])
            time_matrix[i][j] = distance / 5
            if distance > 50.0:
                time_matrix[i][j] = min(time_matrix[i][j],
                                        2 + (distance-50) / 5)
            elif distance == 50.0:
                time_matrix[i][j] = 2.0
            else:
                time_matrix[i][j] = min(time_matrix[i][j],
                                        2 + (50-distance) / 5)
        time_matrix[j][i] = time_matrix[i][j]
        
INT_MAX = int(10e9)
distance = [INT_MAX] * (N+2)
distance[0] = 0
heap = [[0, 0]]

while heap:
    time, next_node = heapq.heappop(heap)
    if distance[next_node] != time:
        continue

    for i in range(len(points)):
        if next_node == i:
            continue
        if distance[i] > time + time_matrix[i][next_node]:
            distance[i] = time + time_matrix[i][next_node]
            heapq.heappush(heap, [distance[i], i])

print("%.6f"%(distance[N+1]))

 

+ Recent posts