Search Options

Results per page
Sort
Preferred Languages
Advance

Results 1 - 9 of 9 for GradientRegistry (0.24 sec)

  1. tensorflow/c/experimental/gradients/not_differentiable.cc

        absl::Span<AbstractTensorHandle*> grad_inputs) {
      for (int i = 0; i < grad_inputs.size(); i++) {
        grad_inputs[i] = nullptr;
      }
      return OkStatus();
    }
    
    // Registers a gradient factory for the op named `op` in `registry`.
    //
    // The factory ignores the forward operation it is handed and always yields a
    // freshly heap-allocated NotDifferentiableGradientFunction.
    // NOTE(review): ownership of that raw pointer presumably passes to whoever
    // invokes the factory — confirm against GradientRegistry's contract.
    //
    // Returns the status produced by `registry->Register`.
    Status RegisterNotDifferentiable(GradientRegistry* registry, const string& op) {
      // The parameter is left unnamed: it is unused, and naming it `op` would
      // shadow the op-name argument of the enclosing function.
      auto make_not_differentiable = [](const ForwardOperation&) {
        return new NotDifferentiableGradientFunction;
      };
      return registry->Register(op, make_not_differentiable);
    }
    }  // namespace gradients
    C++
    - Registered: Tue Feb 27 12:39:08 GMT 2024
    - Last Modified: Wed Jun 15 01:15:58 GMT 2022
    - 1.3K bytes
    - Viewed (0)
  2. tensorflow/c/experimental/gradients/grad_test_helper.cc

        } else {
          ASSERT_NEAR(manuals[j], danalytical[j], abs_error);
        }
      }
    
      TF_DeleteTensor(analytical_tensor);
      delete[] danalytical;
    }
    
    Model BuildGradModel(Model forward, GradientRegistry registry) {
      return [forward_model = std::move(forward),
              grad_registry = std::move(registry)](
                 AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> inputs,
    C++
    - Registered: Tue Mar 26 12:39:09 GMT 2024
    - Last Modified: Wed Feb 28 13:53:47 GMT 2024
    - 5K bytes
    - Viewed (0)
  3. tensorflow/c/eager/gradients.cc

      TF_RETURN_IF_ERROR(
          op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs));
      *result = outputs[0];
      return absl::OkStatus();
    }
    }  // namespace
    
    Status GradientRegistry::Register(
        const string& op_name, GradientFunctionFactory gradient_function_factory) {
      auto iter = registry_.find(op_name);
      if (iter != registry_.end()) {
    C++
    - Registered: Tue Apr 30 12:39:09 GMT 2024
    - Last Modified: Thu Feb 15 09:49:45 GMT 2024
    - 19.3K bytes
    - Viewed (0)
  4. tensorflow/c/experimental/gradients/array_grad_test.cc

        // unstable. Some forward pass tests also fail with TensorFloat-32 due to
        // low tolerances
        enable_tensor_float_32_execution(false);
      }
    
      AbstractContextPtr immediate_execution_ctx_;
      GradientRegistry registry_;
      Status status_;
    
     public:
      bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; }
      bool UseFunction() const { return std::get<2>(GetParam()); }
    };
    
    C++
    - Registered: Tue Mar 26 12:39:09 GMT 2024
    - Last Modified: Wed Feb 28 13:53:47 GMT 2024
    - 5K bytes
    - Viewed (0)
  5. tensorflow/c/experimental/gradients/math_grad_test.cc

        // unstable. Some forward pass tests also fail with TensorFloat-32 due to
        // low tolerances
        enable_tensor_float_32_execution(false);
      }
    
      AbstractContextPtr immediate_execution_ctx_;
      GradientRegistry registry_;
      Status status_;
    
     public:
      bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; }
      bool UseFunction() const { return std::get<2>(GetParam()); }
    };
    
    C++
    - Registered: Tue Mar 26 12:39:09 GMT 2024
    - Last Modified: Thu Apr 13 17:32:14 GMT 2023
    - 16.3K bytes
    - Viewed (0)
  6. tensorflow/c/eager/gradients_test.cc

        TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
        Status s = StatusFromTF_Status(status.get());
        CHECK_EQ(errors::OK, s.code()) << s.message();
      }
    };
    
    // Populates `registry` with the gradient registrations needed here. At
    // present that is a single entry: "CheckNumerics", registered through
    // RegisterNotDifferentiable.
    //
    // Returns absl::OkStatus() once the registration call has gone through.
    // NOTE(review): TF_RETURN_IF_ERROR is expected to return the non-OK status
    // of the wrapped call early — confirm against the macro's definition.
    Status RegisterGradients(GradientRegistry* registry) {
      TF_RETURN_IF_ERROR(RegisterNotDifferentiable(registry, "CheckNumerics"));
      return absl::OkStatus();
    }
    
    TEST_P(CppGradients, TestSetAttrString) {
    C++
    - Registered: Tue Apr 30 12:39:09 GMT 2024
    - Last Modified: Thu Feb 15 09:49:45 GMT 2024
    - 7K bytes
    - Viewed (0)
  7. tensorflow/c/experimental/gradients/tape/tape_context.cc

    #include "tensorflow/c/experimental/gradients/tape/tape_operation.h"
    
    namespace tensorflow {
    namespace gradients {
    // Builds a tape context layered over the parent context `c`.
    //
    // The base AbstractContext is initialised with the kTape kind; `c`, `tape`
    // and `registry` are stored in parent_ctx_, tape_ and registry_ respectively.
    // No reference is taken on the parent context (see the TODO in the body).
    TapeContext::TapeContext(AbstractContext* c, Tape* tape,
                             const GradientRegistry& registry)
        : AbstractContext(kTape),
          parent_ctx_(c),
          tape_(tape),
          registry_(registry) {
      // TODO(srbs): Make AbstractContext ref counted.
      // parent_ctx_->Ref();
    }
    void TapeContext::Release() {
    C++
    - Registered: Tue Feb 27 12:39:08 GMT 2024
    - Last Modified: Wed Sep 23 23:12:39 GMT 2020
    - 1.7K bytes
    - Viewed (0)
  8. tensorflow/c/experimental/gradients/tape/tape_operation.cc

    #include "tensorflow/c/eager/gradients.h"
    
    namespace tensorflow {
    namespace gradients {
    TapeOperation::TapeOperation(AbstractOperation* parent_op, Tape* tape,
                                 const GradientRegistry& registry)
        : AbstractOperation(kTape),
          parent_op_(parent_op),
          tape_(tape),
          registry_(registry) {
      // TODO(b/172003047): Consider making AbstractOperation RefCounted.
      // parent_op_->Ref();
    C++
    - Registered: Tue Feb 27 12:39:08 GMT 2024
    - Last Modified: Tue Jun 07 01:53:35 GMT 2022
    - 9K bytes
    - Viewed (1)
  9. tensorflow/c/experimental/gradients/nn_grad_test.cc

        // unstable. Some forward pass tests also fail with TensorFloat-32 due to
        // low tolerances
        enable_tensor_float_32_execution(false);
      }
    
      AbstractContextPtr immediate_execution_ctx_;
      GradientRegistry registry_;
      Status status_;
    
     public:
      bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; }
      bool UseFunction() const { return std::get<2>(GetParam()); }
    };
    
    C++
    - Registered: Tue Mar 26 12:39:09 GMT 2024
    - Last Modified: Wed Feb 28 13:53:47 GMT 2024
    - 8.3K bytes
    - Viewed (0)
Back to top