Fix merge conflicts with upstream
[akkoma] / lib / pleroma / web / o_auth / o_auth_controller.ex
index d2f9d1cebf379c941df3dfffe46d6b2f19215ac9..73d48549710fcf014e50cfc9008ef2738fa453fd 100644 (file)
@@ -1,10 +1,11 @@
 # Pleroma: A lightweight social networking server
-# Copyright © 2017-2020 Pleroma Authors <https://pleroma.social/>
+# Copyright © 2017-2021 Pleroma Authors <https://pleroma.social/>
 # SPDX-License-Identifier: AGPL-3.0-only
 
 defmodule Pleroma.Web.OAuth.OAuthController do
   use Pleroma.Web, :controller
 
+  alias Pleroma.Helpers.AuthHelper
   alias Pleroma.Helpers.UriHelper
   alias Pleroma.Maps
   alias Pleroma.MFA
@@ -79,6 +80,13 @@ defmodule Pleroma.Web.OAuth.OAuthController do
     available_scopes = (app && app.scopes) || []
     scopes = Scopes.fetch_scopes(params, available_scopes)
 
+    user =
+      with %{assigns: %{user: %User{} = user}} <- conn do
+        user
+      else
+        _ -> nil
+      end
+
     scopes =
       if scopes == [] do
         available_scopes
@@ -88,6 +96,8 @@ defmodule Pleroma.Web.OAuth.OAuthController do
 
     # Note: `params` might differ from `conn.params`; use `@params` not `@conn.params` in template
     render(conn, Authenticator.auth_template(), %{
+      user: user,
+      app: app && Map.delete(app, :client_secret),
       response_type: params["response_type"],
       client_id: params["client_id"],
       available_scopes: available_scopes,
@@ -131,11 +141,13 @@ defmodule Pleroma.Web.OAuth.OAuthController do
     end
   end
 
-  def create_authorization(
-        %Plug.Conn{} = conn,
-        %{"authorization" => _} = params,
-        opts \\ []
-      ) do
+  def create_authorization(_, _, opts \\ [])
+
+  def create_authorization(%Plug.Conn{assigns: %{user: %User{} = user}} = conn, params, []) do
+    create_authorization(conn, params, user: user)
+  end
+
+  def create_authorization(%Plug.Conn{} = conn, %{"authorization" => _} = params, opts) do
     with {:ok, auth, user} <- do_create_authorization(conn, params, opts[:user]),
          {:mfa_required, _, _, false} <- {:mfa_required, user, auth, MFA.require?(user)} do
       after_create_authorization(conn, auth, params)
@@ -248,7 +260,7 @@ defmodule Pleroma.Web.OAuth.OAuthController do
     with {:ok, app} <- Token.Utils.fetch_app(conn),
          {:ok, %{user: user} = token} <- Token.get_by_refresh_token(app, token),
          {:ok, token} <- RefreshToken.grant(token) do
-      json(conn, OAuthView.render("token.json", %{user: user, token: token}))
+      after_token_exchange(conn, %{user: user, token: token})
     else
       _error -> render_invalid_credentials_error(conn)
     end
@@ -260,7 +272,7 @@ defmodule Pleroma.Web.OAuth.OAuthController do
          {:ok, auth} <- Authorization.get_by_token(app, fixed_token),
          %User{} = user <- User.get_cached_by_id(auth.user_id),
          {:ok, token} <- Token.exchange_token(app, auth) do
-      json(conn, OAuthView.render("token.json", %{user: user, token: token}))
+      after_token_exchange(conn, %{user: user, token: token})
     else
       error ->
         handle_token_exchange_error(conn, error)
@@ -275,7 +287,7 @@ defmodule Pleroma.Web.OAuth.OAuthController do
          {:ok, app} <- Token.Utils.fetch_app(conn),
          requested_scopes <- Scopes.fetch_scopes(params, app.scopes),
          {:ok, token} <- login(user, app, requested_scopes) do
-      json(conn, OAuthView.render("token.json", %{user: user, token: token}))
+      after_token_exchange(conn, %{user: user, token: token})
     else
       error ->
         handle_token_exchange_error(conn, error)
@@ -298,7 +310,7 @@ defmodule Pleroma.Web.OAuth.OAuthController do
     with {:ok, app} <- Token.Utils.fetch_app(conn),
          {:ok, auth} <- Authorization.create_authorization(app, %User{}),
          {:ok, token} <- Token.exchange_token(app, auth) do
-      json(conn, OAuthView.render("token.json", %{token: token}))
+      after_token_exchange(conn, %{token: token})
     else
       _error ->
         handle_token_exchange_error(conn, :invalid_credentails)
@@ -308,6 +320,12 @@ defmodule Pleroma.Web.OAuth.OAuthController do
   # Bad request
   def token_exchange(%Plug.Conn{} = conn, params), do: bad_request(conn, params)
 
+  def after_token_exchange(%Plug.Conn{} = conn, %{token: token} = view_params) do
+    conn
+    |> AuthHelper.put_session_token(token.token)
+    |> json(OAuthView.render("token.json", view_params))
+  end
+
   defp handle_token_exchange_error(%Plug.Conn{} = conn, {:mfa_required, user, auth, _}) do
     conn
     |> put_status(:forbidden)
@@ -361,9 +379,17 @@ defmodule Pleroma.Web.OAuth.OAuthController do
     render_invalid_credentials_error(conn)
   end
 
-  def token_revoke(%Plug.Conn{} = conn, %{"token" => _token} = params) do
-    with {:ok, app} <- Token.Utils.fetch_app(conn),
-         {:ok, _token} <- RevokeToken.revoke(app, params) do
+  def token_revoke(%Plug.Conn{} = conn, %{"token" => token}) do
+    with {:ok, %Token{} = oauth_token} <- Token.get_by_token(token),
+         {:ok, oauth_token} <- RevokeToken.revoke(oauth_token) do
+      conn =
+        with session_token = AuthHelper.get_session_token(conn),
+             %Token{token: ^session_token} <- oauth_token do
+          AuthHelper.delete_session_token(conn)
+        else
+          _ -> conn
+        end
+
       json(conn, %{})
     else
       _error ->
@@ -401,7 +427,7 @@ defmodule Pleroma.Web.OAuth.OAuthController do
       |> Map.put("state", state)
 
     # Handing the request to Ueberauth
-    redirect(conn, to: o_auth_path(conn, :request, provider, params))
+    redirect(conn, to: Routes.o_auth_path(conn, :request, provider, params))
   end
 
   def request(%Plug.Conn{} = conn, params) do
@@ -574,9 +600,12 @@ defmodule Pleroma.Web.OAuth.OAuthController do
     end
   end
 
+<<<<<<< HEAD
+=======
   # Special case: Local MastodonFE
-  defp redirect_uri(%Plug.Conn{} = conn, "."), do: auth_url(conn, :login)
+  defp redirect_uri(%Plug.Conn{} = conn, "."), do: Routes.auth_url(conn, :login)
 
+>>>>>>> 0c56f9de0d607b88fd107e0bd13ef286f0629346
   defp redirect_uri(%Plug.Conn{}, redirect_uri), do: redirect_uri
 
   defp get_session_registration_id(%Plug.Conn{} = conn), do: get_session(conn, :registration_id)