Medium
A magician has various spells.
You are given an array power
, where each element represents the damage of a spell. Multiple spells can have the same damage value.
It is a known fact that if a magician decides to cast a spell with a damage of power[i]
, they cannot cast any spell with a damage of power[i] - 2
, power[i] - 1
, power[i] + 1
, or power[i] + 2
.
Each spell can be cast only once.
Return the maximum possible total damage that a magician can cast.
Example 1:
Input: power = [1,1,3,4]
Output: 6
Explanation:
The maximum possible damage of 6 is produced by casting spells 0, 1, 3 with damage 1, 1, 4.
Example 2:
Input: power = [7,1,6,6]
Output: 13
Explanation:
The maximum possible damage of 13 is produced by casting spells 1, 2, 3 with damage 1, 6, 6.
Constraints:
1 <= power.length <= 105
1 <= power[i] <= 109
import java.util.Arrays;
public class Solution {
public long maximumTotalDamage(int[] power) {
int maxPower = 0;
for (int p : power) {
if (p > maxPower) {
maxPower = p;
}
}
return (maxPower <= 1_000_000) ? smallPower(power, maxPower) : bigPower(power);
}
private long smallPower(int[] power, int maxPower) {
int[] counts = new int[maxPower + 6];
for (int p : power) {
counts[p]++;
}
long[] dp = new long[maxPower + 6];
dp[1] = counts[1];
dp[2] = Math.max(counts[2] * 2L, dp[1]);
for (int i = 3; i <= maxPower; i++) {
dp[i] = Math.max((long) counts[i] * i + dp[i - 3], Math.max(dp[i - 1], dp[i - 2]));
}
return dp[maxPower];
}
private long bigPower(int[] power) {
Arrays.sort(power);
int n = power.length;
long[] prevs = new long[4];
int curPower = power[0];
int count = 1;
long result = 0;
for (int i = 1; i <= n; i++) {
int p = (i == n) ? 1_000_000_009 : power[i];
if (p == curPower) {
count++;
} else {
long curVal =
Math.max((long) curPower * count + prevs[3], Math.max(prevs[1], prevs[2]));
int diff = Math.min(p - curPower, prevs.length - 1);
long nextCurVal = (diff == 1) ? 0 : Math.max(prevs[3], Math.max(curVal, prevs[2]));
// Shift the values in prevs[].
int k = prevs.length - 1;
if (diff < prevs.length - 1) {
while (k > diff) {
prevs[k] = prevs[k-- - diff];
}
prevs[k--] = curVal;
}
while (k > 0) {
prevs[k--] = nextCurVal;
}
curPower = p;
count = 1;
}
}
for (long v : prevs) {
if (v > result) {
result = v;
}
}
return result;
}
}