algorithm/백준알고리즘

[백준알고리즘] 2104번: 부분배열 고르기-C++

SURI:) 2022. 9. 17. 23:14
728x90

2104번: 부분배열 고르기 (acmicpc.net)

 

2104번: 부분배열 고르기

크기가 N(1 ≤ N ≤ 100,000)인 1차원 배열 A[1], …, A[N]이 있다. 어떤 i, j(1 ≤ i ≤ j ≤ N)에 대한 점수는, (A[i] + … + A[j]) × min{A[i], …, A[j]}가 된다. 즉, i부터 j까지의 합에 i부터 j까지의 최솟값을 곱

www.acmicpc.net

문제를 푸는데 되게 오래걸렸다.

처음에는 세그먼트 트리를 오랜만에 공부하려고 시작한 거라, 공부하는데 시간이 소요됐다. 그러고 나서는 어떻게 적용되면 좋을지 고민하느라 시간이 소요됐다. 마지막으로, 알 수 없는 '틀렸습니다'의 늪에서 헤어나오지 못했다.

게다가 하루에 30분 저도씩 밖에 시간을 못 쓰는 바람에 며칠이 걸린건지 모르겠다..

 

'틀렸습니다'의 늪에 갇혀있었는데, 구글링을 통해 나랑 같은 방법으로 푼 분의 글을 봤고 도움을 받아 해결할수 있었다.


문제 풀이

풀이에 들어가기 앞서, 세그먼트 트리에 대해서는 Crocus님의 글나동빈님의 글을 보고 다시 익혔다.

 

문제를 해결하기 위해 세그먼트 트리를 2개 이용했다. 즉, 구간별 구하는 값이 2개인 것이다.

하나는 구간별 값들의 합을 위해 사용했다. 다른 하나는 구간별 최소값이 위치한 인덱스를 위해 사용했다. 즉, 이 2가지 세그먼트 트리를 이용해 특정 구간의 값들의 합과 그중 최소값을 구할 수 있게 된다.

 

이를 이용해 문제에서 요구하는 최대값을 구해나가면 된다. 최대값을 구해나가는 방법은 분할 정복과 같이 진행하면 된다.

구간 \([L, R]\) 사이에 최소값의 위치가 \(S\)이라면, \([L, R]\) 에서의 값과, \([L, S-1]\)에서의 값, 그리고 \([S+1, R]\) 에서의 값 중 큰 값을 구하면 된다.

 

'틀렸습니다'의 늪에서 헤어나오기 위해 다른 풀이들을 검색해봤는데, 분할 정복 방법만을 이용해 푸는 경우가 많은 것 같았다. 이 경우에는 최소값 위치와 상관없이 구간 \([L, R]\) 사이의 중앙값 \(M\)을 기준으로 범위 \([L, R]\), \([L, M-1]\), \([M+1, R]\) 구간에서의 값들 중 최대값을 구하도록 했다. 따지면 왼쪽 구간에서의 최대값이 있는 경우, 오른쪽 구간에서 최대값이 있는 경우, 양쪽 구간에 걸쳐서 최대값이 있는 경우다.

풀이를 보면 간단한 분할 정복 종류 중 하나인데, 이것도 생각이 안났었다.. 나도 다시 풀도록 해야겠다.

 

코드 자체가 같은 형태가 많아 더 정리하고 싶은데.. 이 문제에 너무 많은 시간을 쓰는 바람에 더 시간 쓰고 싶지 않아졌다..

#include <cstdio>
#include <cstdint>

void inputArray();
void initMinIndexSegmentTree(const int &nodeIndex, const int &left, const int &right);
void initSumSegmentTree(const int &nodeIndex, const int &left, const int &right);
uint64_t getMaxVal(const int &currentRangeStart, const int &currentRangeEnd);

int N;
int array[100001];
uint64_t sumSegTree[400004];
int minIndexSegTree[400004];

int main(void) {
	inputArray();

	initSumSegmentTree(1, 1, N);
	initMinIndexSegmentTree(1, 1, N);

	printf("%llu", getMaxVal(1, N));
	return 0;
}

void inputArray() {
	::scanf("%d", &N);
	for ( int index = 1; index <= N; ++index ) {
		::scanf("%d", &array[index]);
	}
}

void initMinIndexSegmentTree(const int &nodeIndex, const int &left, const int &right) {
	if ( left == right ) {
		minIndexSegTree[nodeIndex] = left;
		return;
	}

	const int mid = static_cast<int>((left + right) / 2);
	const int leftChildNodeIndex = nodeIndex * 2;
	const int rightChildNodeIndex = leftChildNodeIndex + 1;

	initMinIndexSegmentTree(leftChildNodeIndex, left, mid);
	initMinIndexSegmentTree(rightChildNodeIndex, mid + 1, right);

	const int &leftIndexOfSmallestValue = minIndexSegTree[leftChildNodeIndex];
	const int &rightIndexOfSmallestValue = minIndexSegTree[rightChildNodeIndex];
	minIndexSegTree[nodeIndex] = (array[leftIndexOfSmallestValue] < array[rightIndexOfSmallestValue]) ? leftIndexOfSmallestValue : rightIndexOfSmallestValue;
}

void initSumSegmentTree(const int &nodeIndex, const int &left, const int &right) {
	if ( left == right ) {
		sumSegTree[nodeIndex] = static_cast<uint64_t>(array[left]);
		return;
	}

	const int mid = static_cast<int>((left + right) / 2);
	const int leftChildNodeIndex = nodeIndex * 2;
	const int rightChildNodeIndex = leftChildNodeIndex + 1;

	initSumSegmentTree(leftChildNodeIndex, left, mid);
	initSumSegmentTree(rightChildNodeIndex, mid + 1, right);

	const uint64_t &leftSum = sumSegTree[leftChildNodeIndex];
	const uint64_t &rightSum = sumSegTree[rightChildNodeIndex];
	sumSegTree[nodeIndex] = leftSum + rightSum;
}

int getIndex(const int &currentRangeStart, const int &currentRangeEnd, const int &nodeIndex, const int &left, const int &right) {
	// 현재 범위가 검색 대상 안에 포함된 경우
	if ( left <= currentRangeStart && currentRangeEnd <= right ) {
		return minIndexSegTree[nodeIndex];
	}
	// 현재 범위가 검색 대상 밖인 경우
	if ( left > currentRangeEnd || currentRangeStart > right ) {
		return 0;
	}
	// 현재 범위 중 일부가 검색 대상에 포함된 경우
	const int currentRangeMid = static_cast<const int>((currentRangeStart + currentRangeEnd) / 2);
	const int leftChildNodeIndex = nodeIndex * 2;
	const int rightChildNodeIndex = leftChildNodeIndex + 1;

	const int leftChildIndex  = getIndex(currentRangeStart, currentRangeMid, leftChildNodeIndex, left, right);
	const int rightChildIndex = getIndex(currentRangeMid + 1, currentRangeEnd, rightChildNodeIndex, left, right);
	if ( 0 == leftChildIndex && 0 == rightChildIndex ) {
		return 0;
	}
	if ( 0 == leftChildIndex && 0 < rightChildIndex ) {
		return rightChildIndex;
	}
	if ( 0 < leftChildIndex && 0 == rightChildIndex ) {
		return leftChildIndex;
	}
	return (array[leftChildIndex] < array[rightChildIndex]) ? leftChildIndex : rightChildIndex;
}

uint64_t getSum(const int &currentRangeStart, const int &currentRangeEnd, const int &nodeIndex, const int &left, const int &right) {
	// 현재 범위가 검색 대상 안에 포함된 경우
	if ( left <= currentRangeStart && currentRangeEnd <= right ) {
		return sumSegTree[nodeIndex];
	}
	// 현재 범위가 검색 대상 밖인 경우
	if ( left > currentRangeEnd || currentRangeStart > right ) {
		return 0;
	}
	// 현재 범위 중 일부가 검색 대상에 포함된 경우
	const int currentRangeMid = static_cast<const int>((currentRangeStart + currentRangeEnd) / 2);
	const int leftChildNodeIndex = nodeIndex * 2;
	const int rightChildNodeIndex = leftChildNodeIndex + 1;

	const uint64_t leftChildSum  = getSum(currentRangeStart, currentRangeMid, leftChildNodeIndex, left, right);
	const uint64_t rightChildSum = getSum(currentRangeMid + 1, currentRangeEnd, rightChildNodeIndex, left, right);
	return leftChildSum + rightChildSum;
}

uint64_t getMaxVal(const int &currentRangeStart, const int &currentRangeEnd) {
	// 범위가 벗어난 경우
	if ( currentRangeStart > currentRangeEnd ) {
		return 0;
	}
	// 범위가 1개인 경우
	if ( currentRangeStart == currentRangeEnd ) {
		return static_cast<uint64_t>(array[currentRangeStart]) * static_cast<uint64_t>(array[currentRangeStart]);
	}

	// 현재 범위의 값 구하기
	const int indexOfSmallestVal = getIndex(1, N, 1, currentRangeStart, currentRangeEnd);
	const uint64_t partialSum = getSum(1, N, 1, currentRangeStart, currentRangeEnd);
	uint64_t maxVal = partialSum * array[indexOfSmallestVal];

	// 최소값이 있는 인덱스를 기준으로 나뉜 범위의 값 구하기
	const uint64_t leftMaxVal = getMaxVal(currentRangeStart, indexOfSmallestVal - 1);
	const uint64_t rightMaxVal = getMaxVal(indexOfSmallestVal + 1, currentRangeEnd);

	// 최대값 구하기
	maxVal = (maxVal < leftMaxVal) ? leftMaxVal : maxVal;
	maxVal = (maxVal < rightMaxVal) ? rightMaxVal : maxVal;
	return maxVal;
}
728x90