//====================================================
// best version: clearer and more elegant.
// The return value is the max sum of a path that starts at root and goes
// downward; the overall answer is accumulated in final_max.
// 7 lines of code
// Returns the largest sum of a path that begins at `root` and descends
// through at most one child per level; an empty subtree yields 0.
// As it unwinds, raises `final_max` to the best path sum found in this
// subtree, including paths that bend at a node (left + node + right).
// The caller must seed `final_max` with a value no larger than any path sum.
int get_maxsum_till_root(TreeNode *root, int &final_max) {
    if (!root) {
        return 0;
    }
    const int left_best = get_maxsum_till_root(root->left, final_max);
    const int right_best = get_maxsum_till_root(root->right, final_max);
    const int v = root->val;

    // Path that passes through this node and uses both branches.
    const int through = left_best + right_best + v;
    // Path that starts here: the node alone, or extended by the better child.
    const int downward = max(v, max(left_best, right_best) + v);

    if (through > final_max) {
        final_max = through;
    }
    if (downward > final_max) {
        final_max = downward;
    }
    return downward;
}
//====================================================
// good concise solution
// Post-order helper for the max-path-sum problem.
// Returns the best sum of a path that starts at `root` and goes downward
// (0 for an empty subtree). Updates `final_max` with the best path that
// bends at `root`: the node plus every child branch whose gain is positive.
int dfs(TreeNode *root, int &final_max) {
    if (!root) {
        return 0;
    }
    const int from_left = dfs(root->left, final_max);
    const int from_right = dfs(root->right, final_max);

    // A branch with a non-positive sum can only lower a path, so drop it.
    const int gain_left = from_left > 0 ? from_left : 0;
    const int gain_right = from_right > 0 ? from_right : 0;

    const int bent = root->val + gain_left + gain_right;
    if (bent > final_max) {
        final_max = bent;
    }
    return root->val + max(gain_left, gain_right);
}
//====================================================
// my initial solution, not sure if correct. large test time out.
// Writes into `max_sum` the largest sum of a path that starts at `root` and
// descends through at most one child per level. For an empty subtree,
// `max_sum` is left unchanged, so callers should pre-initialise it (0 makes
// a missing child contribute nothing).
// Cost is proportional to the size of the subtree.
void get_max_sum_from_root(TreeNode *root, int &max_sum) {
    if (!root) {
        return;
    }
    int best_left = 0;   // stays 0 when root->left is empty
    int best_right = 0;  // stays 0 when root->right is empty
    get_max_sum_from_root(root->left, best_left);
    get_max_sum_from_root(root->right, best_right);

    const int extended = max(best_left, best_right) + root->val;
    max_sum = max(root->val, extended);
}
void binary_tree_max_path_sum(TreeNode *root, int &max_sum) {
if(!root) return;
int left_max_sum = 0, right_max_sum = 0;
get_max_sum_from_root(root->left, left_max_sum);
get_max_sum_from_root(root->right, right_max_sum);
int max_sum_path_pass_root = max(left_max_sum, right_max_sum) + root->val;
int larger = max(root->val, max_sum_path_pass_root);
int L = 0, R = 0;
binary_tree_max_path_sum(root->left, L);
binary_tree_max_path_sum(root->right, R);
max_sum = max(max(L, R), larger);
}